Example #1
    def get_heuristic(self, obs):
        # Cost-to-go is zero at the goal; otherwise use the grid heuristic
        if self.env.check_goal(obs):
            return 0
        cell = obs['observation'].copy()
        goal_cell = obs['desired_goal'].copy()

        return compute_heuristic(cell, goal_cell, self.args.goal_threshold)
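The examples on this page all call a compute_heuristic(cell, goal_cell, goal_threshold) helper that is defined elsewhere in the project. As a rough illustration only, assuming integer grid cells and a Manhattan-style distance that drops to zero within the goal threshold (the actual CMAXPP helper may differ), it could look like:

import numpy as np


def compute_heuristic_sketch(cell, goal_cell, goal_threshold):
    # Hypothetical stand-in for compute_heuristic, not the project's code:
    # Manhattan distance between grid cells, treated as 0 once the cell is
    # within goal_threshold of the goal.
    cell = np.asarray(cell, dtype=np.int64)
    goal_cell = np.asarray(goal_cell, dtype=np.int64)
    distance = int(np.abs(cell - goal_cell).sum())
    return 0 if distance <= goal_threshold else distance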
Example #2
File: worker.py Project: vvanirudh/CMAXPP
    def get_qvalue(self, obs, ac):
        if self.env.check_goal(obs):
            return 0
        cell = obs['observation'].copy()
        goal_cell = obs['desired_goal'].copy()
        value = compute_heuristic(cell, goal_cell, self.args.goal_threshold)
        features = compute_features(cell, goal_cell, self.env.carry_cell,
                                    self.env.obstacle_cell_aa,
                                    self.env.obstacle_cell_bb,
                                    self.args.grid_size,
                                    self.env._grid_to_continuous)
        features_norm = self.feature_normalizer_q.normalize(features)
        ac_idx = self.actions_index[ac]
        residual_state_action_value = get_state_action_value_residual(
            features_norm, ac_idx, self.state_action_value_residual)
        return value + residual_state_action_value
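get_state_action_value_residual is another helper assumed by these snippets. A minimal sketch, assuming the residual (self.state_action_value_residual) is a small PyTorch module that maps normalized features to one residual per action (the names and signature here are illustrative, not the project's API):

import torch


def get_state_action_value_residual_sketch(features_norm, ac_idx, residual_net):
    # Hypothetical sketch: forward the normalized features through a network
    # that outputs one residual per action, then pick the chosen action's entry.
    features_tensor = torch.as_tensor(features_norm, dtype=torch.float32)
    with torch.no_grad():
        residuals = residual_net(features_tensor.unsqueeze(0)).squeeze(0)
    return float(residuals[ac_idx].item())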
Example #3
File: worker.py Project: vvanirudh/CMAXPP
    def get_state_value(self, obs, inflated=False):
        if self.env.check_goal(obs):
            return 0
        cell = obs['observation'].copy()
        goal_cell = obs['desired_goal'].copy()
        value = compute_heuristic(cell, goal_cell, self.args.goal_threshold)
        features = compute_features(cell, goal_cell, self.env.carry_cell,
                                    self.env.obstacle_cell_aa,
                                    self.env.obstacle_cell_bb,
                                    self.args.grid_size,
                                    self.env._grid_to_continuous)
        features_norm = self.feature_normalizer.normalize(features)

        # Use inflated if need be
        if inflated:
            state_value_residual = self.inflated_state_value_residual
        else:
            state_value_residual = self.state_value_residual

        residual_value = get_state_value_residual(features_norm,
                                                  state_value_residual)
        return value + residual_value
Example #4
    def learn_online_in_real_world(self):
        current_observation = copy.deepcopy(
            self.env.get_current_observation(goal=True))

        total_n_steps = 0
        current_n_steps = 0
        max_attempts = self.args.max_attempts
        n_attempts = 0
        n_steps = []
        start = time.time()
        while True:
            print('-------------')
            print('Current cell', current_observation['observation'])

            ac, _ = self.controller.act(copy.deepcopy(current_observation))
            if self.rng_exploration.rand() < self.args.epsilon:
                ac_idx = self.rng.randint(len(self.actions))
                ac = self.actions[ac_idx]
            print('Action', ac)
            print('Current state-action value',
                  self.get_qvalue(current_observation, ac))
            print(
                'Current cell heuristic',
                compute_heuristic(current_observation['observation'],
                                  current_observation['desired_goal'],
                                  self.args.goal_threshold))

            # Step in the environment
            next_observation, cost = self.env.step(ac)
            print('True next cell', next_observation['observation'])
            total_n_steps += 1
            current_n_steps += 1

            self.add_to_transition_buffer(n_attempts, current_observation, ac,
                                          cost, next_observation)

            if (current_n_steps + 1) % self.args.qlearning_update_freq == 0:
                for _ in range(self.args.num_updates *
                               self.args.qlearning_update_freq):
                    self.update_state_action_value_residual_qlearning()
                    self.update_target_networks()

            # Check goal
            check_goal = self.env.check_goal(next_observation)
            max_timesteps = current_n_steps >= self.args.max_timesteps
            if check_goal or max_timesteps:
                n_steps.append(current_n_steps)
                current_n_steps = 0
                if check_goal:
                    print_green('Reached goal in ' + str(n_steps[-1]) +
                                ' steps')
                if max_timesteps:
                    print_fail('Maxed out number of steps')
                    break
                print('======================================================')
                self.env.reset(goal=True)
                current_observation = copy.deepcopy(
                    self.env.get_current_observation(goal=True))
                n_attempts += 1
                if n_attempts == max_attempts:
                    break
                continue
                # break

            # Update current observation
            current_observation = copy.deepcopy(next_observation)

        end = time.time()
        print_green('Finished in time ' + str(end - start) + ' secs')
        return n_steps
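The loop above stores every executed transition through add_to_transition_buffer and periodically fits the state-action value residual from that buffer. A minimal sketch of such a buffer, assuming a simple bounded FIFO of transition dictionaries (the buffer in worker.py may be organized differently):

import random
from collections import deque


class TransitionBufferSketch:
    # Hypothetical bounded FIFO of (obs, ac, cost, obs_next) transitions.
    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def add(self, attempt, obs, ac, cost, obs_next):
        self.buffer.append({'attempt': attempt, 'obs': obs, 'ac': ac,
                            'cost': cost, 'obs_next': obs_next})

    def sample(self, batch_size):
        # Sample without replacement, capped at the current buffer size.
        k = min(batch_size, len(self.buffer))
        return random.sample(list(self.buffer), k)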
Example #5
    def learn_online_in_real_world(self):
        current_observation = copy.deepcopy(
            self.env.get_current_observation(goal=True))

        total_n_steps = 0
        current_n_steps = 0
        max_attempts = self.args.max_attempts
        n_attempts = 0
        n_steps = []
        start = time.time()
        while True:
            print('-------------')
            print('Current cell', current_observation['observation'])

            ac, info = self.controller.act(copy.deepcopy(current_observation))
            print('Action', ac)
            ac_inflated, info_inflated = self.controller_inflated.act(
                copy.deepcopy(current_observation))
            print('Inflated action', ac_inflated)

            print('Current cell inflated value', info_inflated['best_node_f'])
            print('Current cell non-inflated value', info['best_node_f'])
            print(
                'Current cell heuristic',
                compute_heuristic(current_observation['observation'],
                                  current_observation['desired_goal'],
                                  self.args.goal_threshold))

            if (info_inflated['best_node_f'] <=
                (1 + self.alpha) * info['best_node_f']) and (ac_inflated
                                                             is not None):
                # CMAX action
                executed_inflated_action = True
                print_blue('Following inflated cost-to-go')
                ac_chosen = ac_inflated
            else:
                # CMAX++ action
                executed_inflated_action = False
                print_green('Following non-inflated cost-to-go')
                ac_chosen = ac

            if (info['best_node_f'] >= 100
                    and info_inflated['best_node_f'] >= 100):
                # ADAPTIVE CMAXPP is stuck
                print_fail("ADAPTIVE CMAXPP is stuck")
                n_steps.append(self.args.max_timesteps)
                current_n_steps = 0
                self.env.reset(goal=True)
                current_observation = copy.deepcopy(
                    self.env.get_current_observation(goal=True))
                n_attempts += 1
                break
            elif (not executed_inflated_action) and (info['best_node_f'] >=
                                                     100):
                # CMAXPP is stuck
                print_fail("CMAXPP is stuck")
                n_steps.append(self.args.max_timesteps)
                current_n_steps = 0
                self.env.reset(goal=True)
                current_observation = copy.deepcopy(
                    self.env.get_current_observation(goal=True))
                n_attempts += 1
                break
            elif executed_inflated_action and (info_inflated['best_node_f'] >=
                                               100):
                # CMAX is stuck
                print_fail("CMAX is stuck")
                n_steps.append(self.args.max_timesteps)
                current_n_steps = 0
                self.env.reset(goal=True)
                current_observation = copy.deepcopy(
                    self.env.get_current_observation(goal=True))
                n_attempts += 1
                break

            if ac_chosen is not None:
                # Step in the environment
                next_observation, cost = self.env.step(ac_chosen)
                print('True next cell', next_observation['observation'])

                # Add to buffers
                # self.add_to_state_buffer(current_observation)
                self.add_to_transition_buffer(n_attempts, current_observation,
                                              ac_chosen, cost,
                                              next_observation)

                if not executed_inflated_action:
                    # CMAXPP
                    next_sim_observation = info['successor_obs']
                    if next_sim_observation is None:
                        # The next successor is unknown to the model
                        # Already discovered discrepancy
                        print_warning(
                            'Executed a previously known to be incorrect transition'
                        )
                        # self.add_to_transition_buffer(n_attempts, current_observation,
                        #                               ac_chosen, cost, next_observation)
                    else:
                        print('Predicted next cell',
                              next_sim_observation['observation'])
                        # Is there a discrepancy?
                        discrepancy_found = self.check_discrepancy(
                            current_observation, ac_chosen, next_observation,
                            next_sim_observation)
                        if discrepancy_found:
                            print_warning('Discrepancy!')
                            # self.add_to_transition_buffer(n_attempts, current_observation,
                            #                               ac_chosen, cost, next_observation)
                            if np.array_equal(
                                    current_observation['observation'],
                                    next_observation['observation']):
                                print_warning('BLOCKING DISCREPANCY')
                            else:
                                print_warning('NON-BLOCKING DISCREPANCY')

                    # self.rollout_in_model(current_observation, inflated=False)
                    # for _ in range(self.args.num_updates):
                    #     for _ in range(self.args.num_updates_q):
                    #         self.update_state_action_value_residual()
                    #     self.update_state_value_residual(inflated=False)
                    #     self.update_target_networks(inflated=False)

                if executed_inflated_action:
                    # CMAX
                    next_sim_observation = info_inflated['successor_obs']
                    print('Predicted next cell',
                          next_sim_observation['observation'])
                    discrepancy_found = self.check_discrepancy(
                        current_observation, ac_chosen, next_observation,
                        next_sim_observation)
                    if discrepancy_found:
                        print_warning('Discrepancy!')
                        # self.add_to_transition_buffer(n_attempts, current_observation,
                        #                               ac_chosen, cost, next_observation)
                        if np.array_equal(current_observation['observation'],
                                          next_observation['observation']):
                            print_warning('BLOCKING DISCREPANCY')
                        else:
                            print_warning('NON-BLOCKING DISCREPANCY')

                    # self.rollout_in_model(current_observation, inflated=True)
                    # for _ in range(self.args.num_updates):
                    #     self.update_state_value_residual(inflated=True)
                    #     self.update_target_networks(inflated=True)
            else:
                next_observation = copy.deepcopy(current_observation)

            self.rollout_in_model(current_observation, inflated=False)
            self.rollout_in_model(current_observation, inflated=True)
            for _ in range(self.args.num_updates):
                self.update_state_action_value_residual()
                self.update_state_value_residual(inflated=False)
                self.update_state_value_residual(inflated=True)
                self.update_target_networks()

            total_n_steps += 1
            current_n_steps += 1

            # Check goal
            check_goal = self.env.check_goal(next_observation)
            max_timesteps = current_n_steps >= self.args.max_timesteps
            if check_goal or max_timesteps:
                n_steps.append(current_n_steps)
                current_n_steps = 0
                if check_goal:
                    # Decrease alpha
                    self.alpha = self.alpha * 0.5
                    print_blue('Changed alpha to ' + str(self.alpha))
                    print_green('Reached goal in ' + str(n_steps[-1]) +
                                ' steps')
                    print_green('Steps so far ' + str(n_steps))
                if max_timesteps:
                    print_fail('Maxed out number of steps')
                    break
                print('======================================================')
                self.env.reset(goal=True)
                current_observation = copy.deepcopy(
                    self.env.get_current_observation(goal=True))
                n_attempts += 1
                if n_attempts == max_attempts:
                    break
                continue
                # break

            # Update current observation
            current_observation = copy.deepcopy(next_observation)

        end = time.time()
        print_green('Finished in time ' + str(end - start) + ' secs')
        return n_steps
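The heart of this adaptive CMAX++ loop is the switch between the inflated (CMAX) and non-inflated (CMAX++) cost-to-go estimates. Restated on its own as a small function (this just mirrors the condition in the loop above, with alpha and the f-values taken to be plain floats):

def choose_action_sketch(ac, f, ac_inflated, f_inflated, alpha):
    # Follow the inflated (CMAX) action whenever its cost-to-go estimate is
    # within a (1 + alpha) factor of the non-inflated (CMAX++) estimate.
    if ac_inflated is not None and f_inflated <= (1 + alpha) * f:
        return ac_inflated, True    # executed_inflated_action = True
    return ac, False                # CMAX++ action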
Example #6
    def learn_online_in_real_world(self):
        current_observation = copy.deepcopy(
            self.env.get_current_observation(goal=True))

        total_n_steps = 0
        current_n_steps = 0
        max_attempts = self.args.max_attempts
        n_attempts = 0
        n_steps = []
        start = time.time()
        while True:
            print('-------------')
            print('Current cell', current_observation['observation'])

            ac, info = self.controller.act(copy.deepcopy(current_observation))
            print('Action', ac)
            print('Current cell value', info['best_node_f'])
            print(
                'Current cell heuristic',
                compute_heuristic(current_observation['observation'],
                                  current_observation['desired_goal'],
                                  self.args.goal_threshold))

            if ac is not None:
                # Step in the environment
                next_observation, cost = self.env.step(ac)
                print('True next cell', next_observation['observation'])

                # Step in the model
                self.controller.model.set_observation(
                    copy.deepcopy(current_observation))
                next_observation_sim, _ = self.controller.model.step(ac)

                # Add to buffers
                self.add_to_dynamics_transition_buffer(current_observation, ac,
                                                       cost, next_observation,
                                                       next_observation_sim)

                next_sim_observation = info['successor_obs']
                print('Predicted next cell',
                      next_sim_observation['observation'])
            else:
                next_observation = copy.deepcopy(current_observation)

            self.rollout_in_model(current_observation)
            self.update_knn_dynamics_residual()
            for _ in range(self.args.num_updates):
                self.update_state_value_residual()
                self.update_target_networks()

            total_n_steps += 1
            current_n_steps += 1

            # Check goal
            check_goal = self.env.check_goal(next_observation)
            max_timesteps = current_n_steps >= self.args.max_timesteps
            if check_goal or max_timesteps:
                n_steps.append(current_n_steps)
                current_n_steps = 0
                if check_goal:
                    print_green('Reached goal')
                    print_green('Reached goal in ' + str(n_steps[-1]) +
                                ' steps')
                if max_timesteps:
                    print_fail('Maxed out number of steps')
                    break
                print('======================================================')
                self.env.reset(goal=True)
                current_observation = copy.deepcopy(
                    self.env.get_current_observation(goal=True))
                n_attempts += 1
                if n_attempts == max_attempts:
                    break
                continue
                # break

            # Update current observation
            current_observation = copy.deepcopy(next_observation)

        end = time.time()
        print_green('Finished in time ' + str(end - start) + ' secs')
        return n_steps
Example #7
    def learn_online_in_real_world(self):
        current_observation = copy.deepcopy(
            self.env.get_current_observation(goal=True))

        total_n_steps = 0
        current_n_steps = 0
        max_attempts = self.args.max_attempts
        n_attempts = 0
        n_steps = []
        start = time.time()
        while True:
            print('-------------')
            print('Current cell', current_observation['observation'])

            ac, info = self.controller.act(copy.deepcopy(current_observation))
            print('Action', ac)
            print('Current cell value', info['best_node_f'])
            print(
                'Current cell heuristic',
                compute_heuristic(current_observation['observation'],
                                  current_observation['desired_goal'],
                                  self.args.goal_threshold))

            if ac is None or info['best_node_f'] >= 100:
                # CMAX is stuck
                print_fail("CMAX is stuck")
                n_steps.append(self.args.max_timesteps)
                current_n_steps = 0
                self.env.reset(goal=True)
                current_observation = copy.deepcopy(
                    self.env.get_current_observation(goal=True))
                n_attempts += 1
                break

            # Step in the environment
            next_observation, cost = self.env.step(ac)
            print('True next cell', next_observation['observation'])
            total_n_steps += 1
            current_n_steps += 1

            # Add to buffers
            # self.add_to_state_buffer(current_observation)

            next_sim_observation = info['successor_obs']
            print('Predicted next cell', next_sim_observation['observation'])

            # Is there a discrepancy?
            discrepancy_found = self.check_discrepancy(current_observation, ac,
                                                       next_observation,
                                                       next_sim_observation)
            if discrepancy_found:
                print_warning('Discrepancy!')
                if np.array_equal(current_observation['observation'],
                                  next_observation['observation']):
                    print_warning('BLOCKING DISCREPANCY')
                else:
                    print_warning('NON-BLOCKING DISCREPANCY')
                # Replan
                # _, info = self.controller.act(
                #     copy.deepcopy(current_observation))

            self.rollout_in_model(current_observation)
            for _ in range(self.args.num_updates):
                # TODO: Should I pass in the current_observation?
                self.update_state_value_residual()
                self.update_target_networks()

            # Check goal
            check_goal = self.env.check_goal(next_observation)
            max_timesteps = current_n_steps >= self.args.max_timesteps
            if check_goal or max_timesteps:
                n_steps.append(current_n_steps)
                current_n_steps = 0
                if check_goal:
                    print_green('Reached goal')
                    print_green('Reached goal in ' + str(n_steps[-1]) +
                                ' steps')
                if max_timesteps:
                    print_fail('Maxed out number of steps')
                    break
                print('======================================================')
                self.env.reset(goal=True)
                current_observation = copy.deepcopy(
                    self.env.get_current_observation(goal=True))
                n_attempts += 1
                if n_attempts == max_attempts:
                    break
                continue
                # break

            # Update current observation
            current_observation = copy.deepcopy(next_observation)

        end = time.time()
        print_green('Finished in time ' + str(end - start) + ' secs')
        return n_steps
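check_discrepancy is not shown on this page; in the loop above its job is to compare the environment's true successor against the model's prediction. A minimal sketch under that assumption (the real method presumably also records the discrepancy, e.g. in the kdtrees used elsewhere, which is omitted here):

import numpy as np


def check_discrepancy_sketch(obs, ac, obs_next_real, obs_next_sim):
    # Hypothetical sketch: flag a discrepancy whenever the executed action led
    # to a different grid cell than the model predicted.
    predicted = np.asarray(obs_next_sim['observation'])
    observed = np.asarray(obs_next_real['observation'])
    return not np.array_equal(predicted, observed)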
Example #8
    def update_state_action_value_residual_workers(self):
        if len(self.transition_buffer) == 0:
            # No incorrect transitions yet
            return

        # Sample a batch of transitions
        transitions = self._sample_transition_batch()
        # Get all the next observations as we need to query the controller
        # for their best estimate of cost-to-go
        observations_next = [
            transition['obs_next'] for transition in transitions
        ]
        batch_size = len(observations_next)

        # Split jobs among workers
        num_workers = self.args.n_workers
        if batch_size < num_workers:
            num_workers = batch_size
        num_per_worker = batch_size // num_workers
        # Put state value residual in object store
        state_value_residual_state_dict_id = ray.put(
            self.state_value_target_residual.state_dict())
        # Put kdtrees in object store
        kdtrees_serialized_id = ray.put(pickle.dumps(self.kdtrees))
        # Put feature normalizer in object store
        feature_normalizer_state_dict_id = ray.put(
            self.feature_normalizer.state_dict())
        # Put feature normalizer q in object store
        feature_normalizer_q_state_dict_id = ray.put(
            self.feature_normalizer_q.state_dict())
        # Put state action value target residual in object store
        state_action_value_residual_state_dict_id = ray.put(
            self.state_action_value_target_residual.state_dict())

        results, count = [], 0
        for worker_id in range(num_workers):
            if worker_id == num_workers - 1:
                # last worker takes the remaining load
                num_per_worker = batch_size - count

            # Set parameters
            ray.get(self.workers[worker_id].set_worker_params.remote(
                state_value_residual_state_dict_id, kdtrees_serialized_id,
                feature_normalizer_state_dict_id,
                state_action_value_residual_state_dict_id,
                feature_normalizer_q_state_dict_id))

            # send job
            results.append(self.workers[worker_id].lookahead_batch.remote(
                observations_next[count:count + num_per_worker]))
            # Increment count
            count += num_per_worker
        # Check if all observations have been accounted for
        assert count == batch_size
        # Get all targets
        results = ray.get(results)
        target_infos = [item for sublist in results for item in sublist]

        cells = [
            transition['obs']['observation'] for transition in transitions
        ]
        goal_cells = [
            transition['obs']['desired_goal'] for transition in transitions
        ]
        actions = [transition['ac'] for transition in transitions]
        ac_idxs = np.array([self.actions_index[ac] for ac in actions],
                           dtype=np.int32)
        costs = np.array([transition['cost'] for transition in transitions],
                         dtype=np.float32)
        heuristics = np.array([
            compute_heuristic(cells[i], goal_cells[i],
                              self.args.goal_threshold)
            for i in range(len(cells))
        ],
                              dtype=np.float32)
        features = np.array([
            compute_features(cells[i], goal_cells[i], self.env.carry_cell,
                             self.env.obstacle_cell_aa,
                             self.env.obstacle_cell_bb, self.args.grid_size,
                             self.env._grid_to_continuous)
            for i in range(len(cells))
        ],
                            dtype=np.float32)
        features_norm = self.feature_normalizer_q.normalize(features)

        # Get next state value
        value_next = np.array([info['best_node_f'] for info in target_infos],
                              dtype=np.float32)
        assert value_next.shape[0] == heuristics.shape[0]

        # Compute targets
        targets = costs + value_next
        residual_targets = targets - heuristics
        # Clip the residual targets such that the residual is always positive
        residual_targets = np.maximum(residual_targets, 0)
        # Clip the residual targets so that the residual is not super big
        residual_targets = np.minimum(residual_targets, 20)

        loss = self._fit_state_action_value_residual(features_norm, ac_idxs,
                                                     residual_targets)
        # Update normalizer
        self.feature_normalizer_q.update_normalizer(features)
        return loss
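This update and the ones in the next two examples convert Bellman-style targets into residual targets for the network: the heuristic is subtracted out and the result is clipped to the range [0, 20]. With made-up numbers, the computation reduces to:

import numpy as np

costs = np.array([1.0, 1.0, 1.0], dtype=np.float32)
value_next = np.array([8.0, 30.0, 0.5], dtype=np.float32)
heuristics = np.array([6.0, 2.0, 3.0], dtype=np.float32)

# cost-to-go target minus heuristic, clipped to stay non-negative and bounded
residual_targets = np.clip(costs + value_next - heuristics, 0.0, 20.0)
print(residual_targets)  # -> [ 3. 20.  0.]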
Example #9
    def update_state_action_value_residual(self):
        if len(self.transition_buffer) == 0:
            # No transitions yet
            return
        # Sample a batch of transitions
        transitions = self._sample_transition_batch()

        cells = [
            transition['obs']['observation'] for transition in transitions
        ]
        goal_cells = [
            transition['obs']['desired_goal'] for transition in transitions
        ]
        actions = [transition['ac'] for transition in transitions]
        ac_idxs = np.array([self.actions_index[ac] for ac in actions],
                           dtype=np.int32)
        costs = np.array([transition['cost'] for transition in transitions],
                         dtype=np.float32)
        cells_next = [
            transition['obs_next']['observation'] for transition in transitions
        ]
        goal_cells_next = [
            transition['obs_next']['desired_goal']
            for transition in transitions
        ]
        heuristics = np.array([
            compute_heuristic(cells[i], goal_cells[i],
                              self.args.goal_threshold)
            for i in range(len(cells))
        ],
                              dtype=np.float32)
        heuristics_next = np.array([
            compute_heuristic(cells_next[i], goal_cells_next[i],
                              self.args.goal_threshold)
            for i in range(len(cells))
        ],
                                   dtype=np.float32)
        features = np.array([
            compute_features(cells[i], goal_cells[i], self.env.carry_cell,
                             self.env.obstacle_cell_aa,
                             self.env.obstacle_cell_bb, self.args.grid_size,
                             self.env._grid_to_continuous)
            for i in range(len(cells))
        ],
                            dtype=np.float32)
        features_norm = self.feature_normalizer_q.normalize(features)

        features_next = np.array([
            compute_features(cells_next[i], goal_cells_next[i],
                             self.env.carry_cell, self.env.obstacle_cell_aa,
                             self.env.obstacle_cell_bb, self.args.grid_size,
                             self.env._grid_to_continuous)
            for i in range(len(cells))
        ],
                                 dtype=np.float32)
        features_next_norm = self.feature_normalizer.normalize(features_next)

        # Compute next state value
        features_next_norm_tensor = torch.from_numpy(features_next_norm)
        with torch.no_grad():
            residual_next_tensor = self.state_value_target_residual(
                features_next_norm_tensor)
            residual_next = residual_next_tensor.detach().numpy().squeeze()
        value_next = residual_next + heuristics_next

        # Compute targets
        targets = costs + value_next
        residual_targets = targets - heuristics
        # Clip the residual targets such that the residual is always positive
        residual_targets = np.maximum(residual_targets, 0)
        # Clip the residual targets so that the residual is not super big
        residual_targets = np.minimum(residual_targets, 20)

        loss = self._fit_state_action_value_residual(features_norm, ac_idxs,
                                                     residual_targets)
        # Update normalizer
        self.feature_normalizer_q.update_normalizer(features)
        self.feature_normalizer.update_normalizer(features_next)

        return loss
Example #10
    def update_state_value_residual(self, inflated=False):
        # Sample batch of states
        observations = self._sample_batch(inflated)
        batch_size = len(observations)

        num_workers = self.args.n_workers
        if batch_size < num_workers:
            num_workers = batch_size
        num_per_worker = batch_size // num_workers
        # Put state value target residual in object store
        state_value_residual_state_dict_id = ray.put(
            self.state_value_target_residual.state_dict())
        # Put kdtrees in object store
        kdtrees_serialized_id = ray.put(pickle.dumps(self.kdtrees))
        # Put feature normalizer in object store
        feature_normalizer_state_dict_id = ray.put(
            self.feature_normalizer.state_dict())

        if self.args.agent in ['cmaxpp', 'adaptive_cmaxpp']:
            # Put feature normalizer q in object store
            feature_normalizer_q_state_dict_id = ray.put(
                self.feature_normalizer_q.state_dict())
            # Put state action value target residual in object store
            state_action_value_residual_state_dict_id = ray.put(
                self.state_action_value_target_residual.state_dict())
        else:
            feature_normalizer_q_state_dict_id = None
            state_action_value_residual_state_dict_id = None

        if self.args.agent == 'adaptive_cmaxpp':
            # Put inflated state value target residual in object store
            inflated_state_value_residual_state_dict_id = ray.put(
                self.inflated_state_value_target_residual.state_dict())
        else:
            inflated_state_value_residual_state_dict_id = None

        if self.args.agent == 'model':
            dynamics_residual_state_dict_id = ray.put(
                self.dynamics_residual.state_dict())
            representation_normalizer_dyn_state_dict_id = ray.put(
                self.representation_normalizer_dyn.state_dict())
        else:
            dynamics_residual_state_dict_id = None
            representation_normalizer_dyn_state_dict_id = None

        if self.args.agent == 'knn':
            knn_dynamics_residuals_serialized_id = ray.put(
                pickle.dumps(self.knn_dynamics_residuals))
        else:
            knn_dynamics_residuals_serialized_id = None

        results, count = [], 0
        for worker_id in range(num_workers):
            if worker_id == num_workers - 1:
                # last worker takes the remaining load
                num_per_worker = batch_size - count

            # Set parameters
            ray.get(self.workers[worker_id].set_worker_params.remote(
                state_value_residual_state_dict_id, kdtrees_serialized_id,
                feature_normalizer_state_dict_id,
                state_action_value_residual_state_dict_id,
                feature_normalizer_q_state_dict_id,
                inflated_state_value_residual_state_dict_id,
                dynamics_residual_state_dict_id,
                knn_dynamics_residuals_serialized_id,
                representation_normalizer_dyn_state_dict_id))

            # send job
            results.append(self.workers[worker_id].lookahead_batch.remote(
                observations[count:count + num_per_worker], inflated))
            # Increment count
            count += num_per_worker
        # Check if all observations have been accounted for
        assert count == batch_size
        # Get all targets
        results = ray.get(results)
        target_infos = [item for sublist in results for item in sublist]

        cells = [
            k.obs['observation'].copy() for info in target_infos
            for k in info['closed']
        ]
        intended_goals = [
            k.obs['desired_goal'].copy() for info in target_infos
            for k in info['closed']
        ]
        assert len(cells) == len(intended_goals)
        heuristics = np.array([
            compute_heuristic(cells[i], intended_goals[i],
                              self.args.goal_threshold)
            for i in range(len(cells))
        ],
                              dtype=np.float32)
        targets = np.array([
            info['best_node_f'] - k._g for info in target_infos
            for k in info['closed']
        ],
                           dtype=np.float32)
        residual_targets = targets - heuristics
        # Clip the residual targets such that the residual is always positive
        residual_targets = np.maximum(residual_targets, 0)
        # Clip the residual targets so that the residual is not super big
        residual_targets = np.minimum(residual_targets, 20)

        # Compute features of the cell
        features = np.array([
            compute_features(cells[i], intended_goals[i], self.env.carry_cell,
                             self.env.obstacle_cell_aa,
                             self.env.obstacle_cell_bb, self.args.grid_size,
                             self.env._grid_to_continuous)
            for i in range(len(cells))
        ],
                            dtype=np.float32)
        features_norm = self.feature_normalizer.normalize(features)

        loss = self._fit_state_value_residual(features_norm, residual_targets,
                                              inflated)
        # Update target network
        # if not inflated:
        #     self._update_target_network(self.state_value_target_residual,
        #                                 self.state_value_residual)
        # else:
        #     self._update_target_network(self.inflated_state_value_target_residual,
        #                                 self.inflated_state_value_residual)
        # Update normalizer
        self.feature_normalizer.update_normalizer(features)
        return loss
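The worker fan-out used in examples #8 and #10 follows one pattern: split the batch evenly across Ray workers, let the last worker absorb the remainder, and assert that every item was assigned. Stripped of the Ray calls, the split logic alone is roughly:

def split_batch_sketch(batch, n_workers):
    # Even split with the last worker taking the remaining load, mirroring
    # the count/num_per_worker bookkeeping in the examples above.
    if not batch:
        return []
    n_workers = min(n_workers, len(batch))
    num_per_worker = len(batch) // n_workers
    chunks, count = [], 0
    for worker_id in range(n_workers):
        if worker_id == n_workers - 1:
            num_per_worker = len(batch) - count
        chunks.append(batch[count:count + num_per_worker])
        count += num_per_worker
    assert count == len(batch)
    return chunks


# split_batch_sketch(list(range(10)), 3) -> [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]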