Example 1
def _wrap_legacy_request(request):
    """Wraps a single legacy request."""
    if isinstance(request, NetworkRequest):
        return Request(RequestType.AGENT_NETWORK)
    elif isinstance(request, np.ndarray):
        return Request(RequestType.AGENT_PREDICTION, content=request)
    else:
        return None
Example 2
    def _batch_predict_steps(self, observations, actions):
        """Predicts next state, reward and done.

        Args:
            observations (np.ndarray): Array of shape (batch, height, width,
                channels) of one-hot encoded observations (along axis=-1).
            actions (np.ndarray): Array of shape (batch,) of actions performed
                by agents.

        Yields:
            request (Request): Model prediction request with one-hot encoded
                input states and actions; handled by RequestHandler.

        Returns:
            next_state (np.ndarray): Array of shape
                (batch, height, width, n_channels) of one-hot encoded state.
            reward (np.ndarray): Array of shape (batch,) of rewards received
                by agents.
            done (np.ndarray): Array of shape (batch,) indicating if the
                episode was terminated.
        """
        assert observations.shape[1:] == self.observation_space.shape
        xs = insert_action_channels(observations, actions, self.action_space.n)
        model_outputs = yield Request(RequestType.MODEL_PREDICTION, xs)
        return self.transform_model_outputs(observations, model_outputs)
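This method is a generator coroutine: it yields a Request, receives the handler's response through send(), and delivers its result as the generator's return value. A minimal driver sketch, with a hypothetical handle callable standing in for the real RequestHandler:

    def drive(coroutine, handle):
        """Runs a request-yielding coroutine to completion.

        `handle` maps each yielded Request to its response; the
        coroutine's return value arrives as StopIteration.value.
        """
        try:
            response = None
            while True:
                request = coroutine.send(response)
                response = handle(request)
        except StopIteration as e:
            return e.value

The same driving pattern applies to all yield-based methods in the examples below.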
Example 3
    def reset(self, env, observation):
        yield from super().reset(env, observation)

        if not self._use_trainable_env:
            self._model = TrainableModelEnv.wrap_perfect_env(env)
        else:
            if self._model is None:
                # Deferred construction. The model will be reused in all
                # subsequent calls to reset().
                self._model = self._model_class(modeled_env=env)
            if self._use_model_ensembles:
                self._model.set_global_index_mask(
                    self._model_ensemble_size, self._model_ensemble_mask_size)

        # Reset MCTS internal state: _bonuses and _state2node.
        self._bonuses = deque([], maxlen=self._bonus_queue_length)
        self._state2node = {}
        if not self._use_trainable_env:
            state = self._model.clone_state()
        else:
            state = self._model.obs2state(observation)
        [[value]] = yield Request(RequestType.AGENT_PREDICTION,
                                  np.array([observation]))

        if self._use_ensembles:
            self._value_acc_class.set_global_index_mask(
                self._ensemble_size, self._ensemble_mask_size)

        # Initialize root.
        graph_node = self._initialize_graph_node(initial_value=value,
                                                 state=state,
                                                 done=False,
                                                 solved=False)
        self._root = TreeNode(graph_node)
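The [[value]] destructuring above assumes the prediction response is a nested sequence of shape (1, 1): one observation in the batch, one scalar value per observation. A tiny sketch with made-up numbers:

    import numpy as np

    response = np.array([[0.5]])   # (batch=1, values_per_obs=1)
    [[value]] = response           # unpack the single row, then its single entry
    assert value == 0.5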
Example 4
    def _expand_graph_node(self, node):
        assert bool(node.rewards) == bool(node.edges)
        if node.edges:
            return  # graph node is expanded already

        # Neighbours are ordered in the order of actions:
        # 0, 1, ..., _model.num_actions - 1.
        observations, rewards, dones, infos, states = \
            yield from self._model.predict_steps(
                node.state,
                list(space_utils.element_iter(self._model.action_space))
            )
        solved = [info.get('solved', False) for info in infos]
        node.bonus = [info.get('bonus', 0.) for info in infos]
        self._bonuses.append(max(node.bonus))

        value_batch = yield Request(RequestType.AGENT_PREDICTION,
                                    np.array(observations))

        for idx, action in enumerate(
                space_utils.element_iter(self._action_space)):
            node.rewards[action] = rewards[idx]
            new_node = self._state2node.get(states[idx], None)
            if new_node is None:
                if dones[idx]:
                    child_value = self._value_traits.zero
                else:
                    [child_value] = value_batch[idx]
                new_node = self._initialize_graph_node(child_value,
                                                       states[idx],
                                                       dones[idx],
                                                       solved=solved[idx])
            node.edges[action] = new_node
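Worth noting: _state2node works as a transposition table here, so a predicted state that was visited before reuses its existing GraphNode instead of creating a duplicate; the search structure is effectively a graph rather than a tree.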
Example 5
    def _batch_predict_steps(self, observations, actions):
        """Predicts next state, reward and done.

        Args:
            observations (np.ndarray): Array of shape (batch, height, width,
                channels) of one-hot encoded observations (along axis=-1).
            actions (np.ndarray): Array of shape (batch,) of actions performed
                by agents.

        Yields:
            request (Request): Model prediction request with one-hot encoded
                input states and actions; handled by RequestHandler.

        Returns:
            next_observations (np.ndarray): Array of shape
                (batch, height, width, n_channels) of one-hot encoded states.
            rewards (np.ndarray): Array of shape (batch,) of rewards received
                by agents.
            dones (np.ndarray): Array of shape (batch,) indicating if episodes
                were terminated.
            infos (list): List of per-transition info dicts.
        """
        assert observations.shape[1:] == self.observation_space.shape

        request_content = np.concatenate(
            (observations, np.expand_dims(actions, axis=-1)),
            axis=-1
        )
        res = yield Request(RequestType.MODEL_PREDICTION, request_content)
        next_observations, rewards, dones = res.T
        next_observations = np.stack(next_observations).astype(float)
        rewards = rewards.astype(float)
        dones = dones.astype(bool)
        infos = [{'solved': done} for done in dones]

        return next_observations, rewards, dones, infos
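A small sketch of the res.T unpacking above, assuming the handler returns an object array with one row per transition and columns (next_observation, reward, done); the shapes are made up:

    import numpy as np

    rows = [(np.zeros((4, 4, 7)), 0.0, False),
            (np.ones((4, 4, 7)), 1.0, True)]
    res = np.empty((len(rows), 3), dtype=object)
    for i, row in enumerate(rows):
        for j, item in enumerate(row):
            res[i, j] = item

    next_observations, rewards, dones = res.T  # three (batch,) object arrays
    assert np.stack(next_observations).astype(float).shape == (2, 4, 4, 7)
    assert rewards.astype(float).tolist() == [0.0, 1.0]
    assert dones.astype(bool).tolist() == [False, True]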
Example 6
    def act(self, observation):
        self._observations.append(observation)
        for _ in range(self._max_n_requests):
            # End the predictions at random times.
            if random.random() < 0.5:
                break
            response = yield Request(RequestType.AGENT_PREDICTION,
                                     np.array([self._requests.pop(0)]))
            self._responses.append(response[0])
        return self._actions.pop(0), {}
Example 7
    def _handle_env_feedback(self, agent_info, action, next_observation,
                             reward, done, env_info):
        """Handles model's mispredictions."""

        if not self._use_trainable_env:
            # We use a perfect model, so there are no mispredictions
            # to handle.
            return
        root_parent = agent_info['node']
        true_state = self._model.obs2state(next_observation)
        solved = env_info.get('solved', False)

        # Correct mispredicted reward.
        root_parent.rewards[action] = reward

        if self._current_node.state != true_state:
            # self._model predicted a wrong state; initialize a new tree
            # from the true state.
            new_node = self._state2node.get(true_state, None)
            if new_node is None:
                # True next state was not visited previously.
                # Initialize new GraphNode.
                if done:
                    value = self._value_traits.zero
                else:
                    # Batch stepper requires all requests submitted at the same
                    # time to have equal shape. The only other place that sends
                    # requests is the self._expand_leaf() method, which sends
                    # `n_actions` observations - so we do the same here.
                    #
                    # A possible improvement: allow the batch stepper to accept
                    # different numbers of observations from different agents.
                    n_actions = space_utils.max_size(self._model.action_space)
                    response = yield Request(
                        RequestType.AGENT_PREDICTION,
                        np.array([next_observation] * n_actions))
                    [value] = response[0]  # we ignore all other responses

                new_node = self._initialize_graph_node(value, true_state, done,
                                                       solved)
            # Correct the mispredicted state in the GraphNode, so we won't
            # make the same mistake again.
            root_parent.edges[action] = new_node
            self._current_node = new_node

        self._current_node.terminal = done
        self._current_node.solved = solved
Example 8
    def solve(self, env, epoch=None, init_state=None, time_limit=None):
        del env
        del epoch
        del init_state
        del time_limit
        for _ in range(self._max_n_requests):
            # End the predictions at random times.
            if random.random() < 0.5:
                break

            response_value = yield Request(self._request_type,
                                           self._request_value)
            self._check_assertions(response_value)

        return 0, {}
Example 9
    def _expand_graph_node(self, node):
        assert bool(node.rewards) == bool(node.edges)
        if (len(node.edges) > 0 or  # graph node is expanded already
                node.solved or node.terminal):
            return

        # Neighbours are ordered in the order of actions:
        # 0, 1, ..., _model.num_actions - 1.
        observations, rewards, dones, infos, states = \
            yield from self._model.predict_steps(
                node.state,
                list(space_utils.element_iter(self._model.action_space))
            )
        # solved = [info.get('solved', False) for info in infos]
        assert all([reward in (0, 1) for reward in rewards]), \
            'We assume the env is deterministic and has goal states: ' \
            'reaching a goal state gives reward=1 and ends the episode; ' \
            'all other transitions give reward=0.'
        solved = [reward == 1 for reward in rewards]

        node.bonus = [
            self._filter_bonus(info.get('bonus', 0.), reward, done)
            for info, reward, done in zip(infos, rewards, dones)
        ]
        node.value_acc.add_bonus(max(node.bonus))
        self._bonuses.append(max(node.bonus))

        value_batch = yield Request(RequestType.AGENT_PREDICTION,
                                    np.array(observations))

        for idx, action in enumerate(
                space_utils.element_iter(self._action_space)):
            node.rewards[action] = rewards[idx]
            new_node = self._state2node.get(states[idx], None)
            if new_node is None:
                if dones[idx]:
                    child_value = self._value_traits.zero
                else:
                    [child_value] = value_batch[idx]
                new_node = self._initialize_graph_node(child_value,
                                                       states[idx],
                                                       dones[idx],
                                                       solved=solved[idx])
            node.edges[action] = new_node
        self._update_from_node(node)
Example 10
    def batched_request(self):
        """Batches requests and returns batched request."""
        if self._batched_request is not None:
            return self._batched_request

        data.nested_map(_PredictionRequestBatcher._assert_not_scalar,
                        self._requests)

        # Stack instead of concatenate to ensure that all requests have
        # the same shape.
        batched_request_content = data.nested_stack(
            [request.content for request in self._requests])
        # (n_agents, n_requests, ...) -> (n_agents * n_requests, ...)
        batched_request_content = data.nested_map(
            _PredictionRequestBatcher._flatten_first_2_dims,
            batched_request_content)
        self._batched_request = Request(self._request_type,
                                        batched_request_content)
        return self._batched_request
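What the stack-then-flatten above presumably does for a plain ndarray leaf (data.nested_stack and _flatten_first_2_dims generalize this over nested structures); shapes are made up:

    import numpy as np

    # Two agents, 3 requests each, feature size 8.
    contents = [np.zeros((3, 8)), np.ones((3, 8))]

    batched = np.stack(contents)                          # (n_agents, n_requests, 8)
    batched = batched.reshape((-1,) + batched.shape[2:])  # (n_agents * n_requests, 8)
    assert batched.shape == (6, 8)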
Example 11
    def solve(self, _):
        for _ in range(self._n_requests):
            _, params = yield Request(RequestType.AGENT_NETWORK)
            assert params == self._xparams
        return RequestType.AGENT_NETWORK, self._xparams
Example 12
    def solve(self, _):
        """Mock solve method."""
        network_fn, params = yield Request(RequestType.AGENT_NETWORK)
        assert isinstance(network_fn(), network_class)
        assert params == xparams
        return episode