Example #1
def compute_target(policy,
                   sample_batch,
                   other_agent_batches=None,
                   episode=None):
    sample_batch_ = {
        key: sample_batch[key]
        for key in
        [SampleBatch.OBS, SampleBatch.PREV_ACTIONS, SampleBatch.ACTIONS]
    }
    sample_batch_ = convert_to_torch_tensor(sample_batch_, policy.device)

    next_obs = restore_original_dimensions(
        convert_to_torch_tensor(sample_batch[SampleBatch.NEXT_OBS]),
        policy.model.obs_space, policy.framework)
    sample_batch['battle_won'] = convert_to_non_torch_type(
        next_obs['battle_won'])
    target_q_values = policy.model.q_values(sample_batch_, target=True)
    target_q_values = convert_to_non_torch_type(target_q_values)
    actions = sample_batch[SampleBatch.ACTIONS]
    actions = actions.reshape(actions.shape[:1] +
                              (policy.model.nbr_agents, -1))
    target = np.take_along_axis(target_q_values, actions, axis=-1)
    target = np.squeeze(target, -1)
    reward = np.expand_dims(sample_batch[SampleBatch.REWARDS], -1)
    gamma = policy.config['gamma']
    lambda_ = policy.config['lambda']
    y = np.zeros_like(target)
    y[-1] = reward[-1]
    for t in range(y.shape[0] - 2, -1, -1):
        y[t] = reward[t] + (
            1 - lambda_) * gamma * target[t + 1] + gamma * lambda_ * y[t + 1]

    sample_batch[Postprocessing.VALUE_TARGETS] = y

    return sample_batch
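
For reference, the backward loop above computes a lambda-return target: y[-1] = reward[-1] at the final step and, going backwards, y[t] = reward[t] + (1 - lambda) * gamma * target[t + 1] + gamma * lambda * y[t + 1]. Below is a self-contained NumPy sketch of just that recursion; the rewards, Q-targets, gamma and lambda are made-up illustration values, not taken from the example.

import numpy as np

# Illustration only: three timesteps of rewards and bootstrapped Q-targets.
reward = np.array([[1.0], [0.0], [2.0]])
target = np.array([[0.5], [0.7], [0.9]])
gamma, lambda_ = 0.99, 0.8

y = np.zeros_like(target)
y[-1] = reward[-1]
for t in range(y.shape[0] - 2, -1, -1):
    # Blend the one-step bootstrap with the longer lambda-return tail.
    y[t] = (reward[t] + (1 - lambda_) * gamma * target[t + 1]
            + gamma * lambda_ * y[t + 1])
print(y)  # approx. [[2.534], [1.762], [2.000]]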
Example #2
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):
        # The Exploration class takes the action distribution, timestep, and
        # explore flag, and returns a torch/tf action tensor.
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                "is_training": False,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            state_batches = [
                self._convert_to_tensor(s) for s in (state_batches or [])
            ]
            # Use action_sampler_fn variant
            action_dist = dist_inputs = None
            state_out = []
            actions, logp = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                explore=explore,
                timestep=timestep)

            input_dict[SampleBatch.ACTIONS] = actions

            # Add default and custom fetches.
            extra_fetches = self.extra_action_out(input_dict, state_batches,
                                                  self.model, action_dist)
            # Action-logp and action-prob.
            if logp is not None:
                logp = convert_to_non_torch_type(logp)
                extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            return convert_to_non_torch_type((actions, state_out,
                                              extra_fetches))
Example #3
        def postprocess_trajectory(self,
                                   sample_batch,
                                   other_agent_batches=None,
                                   episode=None):
            if not postprocess_fn:
                return sample_batch

            # Do all post-processing always with no_grad().
            # Not using this here will introduce a memory leak (issue #6962).
            with torch.no_grad():
                return postprocess_fn(
                    self, convert_to_non_torch_type(sample_batch),
                    convert_to_non_torch_type(other_agent_batches), episode)
Example #4
    def compute_actions_from_input_dict(
            self,
            input_dict: Dict[str, TensorType],
            explore: bool = None,
            timestep: Optional[int] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None

            actions, state_out, extra_fetches, logp = \
                self._compute_action_helper(
                    input_dict, state_batches, seq_lens, explore, timestep)

            # Leave outputs as is (torch.Tensors): Action-logp and action-prob.
            if logp is not None:
                extra_fetches[SampleBatch.ACTION_PROB] = torch.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp

            return convert_to_non_torch_type(
                (actions, state_out, extra_fetches))
Example #5
 def extra_grad_info(self, train_batch):
     with torch.no_grad():
         if stats_fn:
             stats_dict = stats_fn(self, train_batch)
         else:
             stats_dict = TorchPolicy.extra_grad_info(self, train_batch)
         return convert_to_non_torch_type(stats_dict)
Example #6
 def get_state(self) -> Union[Dict[str, TensorType], List[TensorType]]:
     state = super().get_state()
     state["_optimizer_variables"] = []
     for i, o in enumerate(self._optimizers):
         optim_state_dict = convert_to_non_torch_type(o.state_dict())
         state["_optimizer_variables"].append(optim_state_dict)
     return state
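
A corresponding set_state (not part of the example above) would feed these state dicts back into the torch optimizers. A minimal sketch, assuming the saved list lines up with self._optimizers by index and that RLlib's convert_to_torch_tensor helper and self.device are available:

 def set_state(self, state) -> None:
     # Restore each optimizer from its saved (numpy) state dict, if any.
     optimizer_vars = state.get("_optimizer_variables", [])
     for o, s in zip(self._optimizers, optimizer_vars):
         o.load_state_dict(convert_to_torch_tensor(s, device=self.device))
     super().set_state(state)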
Example #7
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            state_batches = [
                self._convert_to_tensor(s) for s in (state_batches or [])
            ]

            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(timestep=timestep)

            model_out = self.model(input_dict, state_batches,
                                   self._convert_to_tensor([1]))
            logits, state = model_out
            action_dist = None
            actions, logp = \
                self.exploration.get_exploration_action(
                    logits, self.dist_class, self.model,
                    timestep, explore)
            input_dict[SampleBatch.ACTIONS] = actions

            extra_action_out = self.extra_action_out(input_dict, state_batches,
                                                     self.model, action_dist)
            if logp is not None:
                logp = convert_to_non_torch_type(logp)
                extra_action_out.update({
                    ACTION_PROB: np.exp(logp),
                    ACTION_LOGP: logp
                })
            return convert_to_non_torch_type(
                (actions, state, extra_action_out))
Example #8
 def extra_compute_grad_fetches(self):
     if extra_learn_fetches_fn:
         fetches = convert_to_non_torch_type(
             extra_learn_fetches_fn(self))
         # Auto-add empty learner stats dict if needed.
         return dict({LEARNER_STATS_KEY: {}}, **fetches)
     else:
         return parent_cls.extra_compute_grad_fetches(self)
Example #9
    def compute_actions(
            self,
            obs_batch: Union[List[TensorType], TensorType],
            state_batches: Optional[List[TensorType]] = None,
            prev_action_batch: Union[List[TensorType], TensorType] = None,
            prev_reward_batch: Union[List[TensorType], TensorType] = None,
            info_batch: Optional[Dict[str, list]] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: np.asarray(obs_batch),
                "is_training": False,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    np.asarray(prev_reward_batch)
            state_batches = [
                convert_to_torch_tensor(s, self.device)
                for s in (state_batches or [])
            ]
            actions, state_out, extra_fetches, logp = \
                self._compute_action_helper(
                    input_dict, state_batches, seq_lens, explore, timestep)

            # Action-logp and action-prob.
            if logp is not None:
                logp = convert_to_non_torch_type(logp)
                extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp

            return convert_to_non_torch_type(
                (actions, state_out, extra_fetches))
Example #10
def to_float_np_array(v: List[Any]) -> np.ndarray:
    # Convert any torch tensors to numpy first so np.array() below works.
    if torch.is_tensor(v[0]):
        v = convert_to_non_torch_type(v)
    arr = np.array(v)
    if arr.dtype == np.float64:
        return arr.astype(np.float32)  # save some memory
    return arr
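
As a quick usage note, to_float_np_array downcasts float64 input to float32 to save memory and returns other dtypes unchanged; the literal list here is purely illustrative:

>>> to_float_np_array([1.0, 2.0, 3.0]).dtype
dtype('float32')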
Example #11
 def extra_action_out(self, input_dict, state_batches, model,
                      action_dist):
     with torch.no_grad():
         if extra_action_out_fn:
             stats_dict = extra_action_out_fn(
                 self, input_dict, state_batches, model, action_dist)
         else:
             stats_dict = TorchPolicy.extra_action_out(
                 self, input_dict, state_batches, model, action_dist)
         return convert_to_non_torch_type(stats_dict)
Example #12
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        explore=None,
        timestep=None,
        **kwargs,
    ):
        # pylint:disable=too-many-arguments,too-many-locals
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        input_dict = self.lazy_tensor_dict({
            SampleBatch.CUR_OBS: obs_batch,
            "is_training": False
        })
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
        state_batches = convert_to_torch_tensor(state_batches or [],
                                                device=self.device)

        # Call the exploration before_compute_actions hook.
        self.exploration.before_compute_actions(timestep=timestep)

        dist_inputs, state_out = self._compute_module_output(
            self._unpack_observations(input_dict),
            state_batches,
            self.convert_to_tensor([1]),
        )

        # pylint:disable=not-callable
        action_dist = self.dist_class(dist_inputs, self.module)
        # pylint:enable=not-callable
        actions, logp = self.exploration.get_exploration_action(
            action_distribution=action_dist,
            timestep=timestep,
            explore=explore)
        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self._extra_action_out(input_dict, state_batches,
                                               self.module, action_dist)

        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = logp.exp()
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        return convert_to_non_torch_type((actions, state_out, extra_fetches))
Example #13
    def compute_actions_from_input_dict(self,
                                        input_dict,
                                        explore=None,
                                        timestep=None,
                                        episodes=None,
                                        **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        with torch.no_grad():
            input_dict = self._lazy_tensor_dict(input_dict)
            # state_batches for RNN
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            seq_lens = np.array(
                [1] * len(input_dict["obs"])) if state_batches else None
            self._is_recurrent = state_batches is not None and state_batches != []
            self.model.eval()
            # print(len(input_dict['obs']))

            dist_inputs, state_out = self.model(input_dict, state_batches,
                                                seq_lens)
            action_dist = self.dist_class(dist_inputs, self.model)
            # Get the exploration action from the forward results.
            actions, logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)
            # add extra info to the trajectory
            extra_info = {}
            # get values from the critic after doing inference for the actions
            extra_info[SampleBatch.VF_PREDS] = self.model.value_function()
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_info[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

            # Action-logp and action-prob.
            if logp is not None:
                extra_info[SampleBatch.ACTION_PROB] = \
                    torch.exp(logp.float())
                extra_info[SampleBatch.ACTION_LOGP] = logp

            # Update our global timestep by the batch size.
            self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

            return convert_to_non_torch_type((actions, state_out, extra_info))
Example #14
    def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                               explore, timestep):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            Tuple:
                - actions, state_out, extra_fetches, logp.
        """
        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep
        self._is_recurrent = state_batches is not None and state_batches != []

        # Switch to eval mode.
        if self.model:
            self.model.eval()

        if self.action_sampler_fn:
            action_dist = dist_inputs = None
            actions, logp, state_out = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                state_batches,
                explore=explore,
                timestep=timestep)
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore,
                                                    timestep=timestep)
            if self.action_distribution_fn:
                # Try new action_distribution_fn signature, supporting
                # state_batches and seq_lens.
                try:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict=input_dict,
                            state_batches=state_batches,
                            seq_lens=seq_lens,
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
                # Trying the old way (to stay backward compatible).
                # TODO: Remove in future.
                except TypeError as e:
                    if "positional argument" in e.args[0] or \
                            "unexpected keyword argument" in e.args[0]:
                        dist_inputs, dist_class, state_out = \
                            self.action_distribution_fn(
                                self,
                                self.model,
                                input_dict[SampleBatch.CUR_OBS],
                                explore=explore,
                                timestep=timestep,
                                is_training=False)
                    else:
                        raise e
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches,
                                                    seq_lens)

            if not (isinstance(dist_class, functools.partial)
                    or issubclass(dist_class, TorchDistributionWrapper)):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__))
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(input_dict, state_batches,
                                              self.model, action_dist)

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = \
                torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_non_torch_type((actions, state_out, extra_fetches))
Example #15
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: np.asarray(obs_batch),
                "is_training": False,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = \
                    np.asarray(prev_action_batch)
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = \
                    np.asarray(prev_reward_batch)
            state_batches = [
                convert_to_torch_tensor(s) for s in (state_batches or [])
            ]

            if self.action_sampler_fn:
                action_dist = dist_inputs = None
                state_out = []
                actions, logp = self.action_sampler_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=explore,
                    timestep=timestep)
            else:
                # Call the exploration before_compute_actions hook.
                self.exploration.before_compute_actions(
                    explore=explore, timestep=timestep)
                if self.action_distribution_fn:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
                else:
                    dist_class = self.dist_class
                    dist_inputs, state_out = self.model(
                        input_dict, state_batches, seq_lens)
                if not (isinstance(dist_class, functools.partial)
                        or issubclass(dist_class, TorchDistributionWrapper)):
                    raise ValueError(
                        "`dist_class` ({}) not a TorchDistributionWrapper "
                        "subclass! Make sure your `action_distribution_fn` or "
                        "`make_model_and_action_dist` return a correct "
                        "distribution class.".format(dist_class.__name__))
                action_dist = dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                actions, logp = \
                    self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            input_dict[SampleBatch.ACTIONS] = actions

            # Add default and custom fetches.
            extra_fetches = self.extra_action_out(input_dict, state_batches,
                                                  self.model, action_dist)
            # Action-logp and action-prob.
            if logp is not None:
                logp = convert_to_non_torch_type(logp)
                extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            return convert_to_non_torch_type((actions, state_out,
                                              extra_fetches))
Example #16
    def compute_actions_from_input_dict(self,
                                        input_dict,
                                        explore=None,
                                        timestep=None,
                                        **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None

            self._is_recurrent = state_batches is not None and state_batches != []

            # Switch to eval mode.
            self.model.eval()

            infos = input_dict[
                SampleBatch.INFOS] if SampleBatch.INFOS in input_dict else {}

            valid_action_trees = []
            for info in infos:
                if isinstance(info, dict) and 'valid_action_tree' in info:
                    valid_action_trees.append(info['valid_action_tree'])
                else:
                    valid_action_trees.append({})

            extra_fetches = {}

            actions_per_step = self.config['actions_per_step']
            autoregressive_actions = self.config['autoregressive_actions']

            step_actions_list = []
            step_masked_logits_list = []
            step_logp_list = []
            step_mask_list = []

            observation_features, state_out = self.model.observation_features_module(
                input_dict, state_batches, seq_lens)
            action_features, _ = self.model.action_features_module(
                input_dict, state_batches, seq_lens)

            embedded_action = None
            for a in range(actions_per_step):
                if autoregressive_actions:
                    if a == 0:
                        batch_size = action_features.shape[0]
                        previous_action = torch.zeros([
                            batch_size,
                            len(self.model.action_space_parts)
                        ]).to(action_features.device)
                    else:
                        previous_action = actions

                    embedded_action = self.model.embed_action_module(
                        previous_action)

                dist_inputs = self.model.action_module(action_features,
                                                       embedded_action)

                exploration = TorchAutoCATExploration(
                    self.model,
                    dist_inputs,
                    valid_action_trees,
                )

                actions, masked_logits, logp, mask = \
                    exploration.get_actions_and_mask()

                # Remove the performed action from the trees
                for batch_action, batch_tree in zip(actions,
                                                    valid_action_trees):
                    x = int(batch_action[0])
                    y = int(batch_action[1])
                    # Assuming we have x,y coordinates
                    del batch_tree[x][y]
                    if len(batch_tree[x]) == 0:
                        del batch_tree[x]

                step_actions_list.append(actions)
                step_masked_logits_list.append(masked_logits)
                step_logp_list.append(logp)
                step_mask_list.append(mask)

            step_actions = tuple(step_actions_list)
            step_masked_logits = torch.hstack(step_masked_logits_list)
            step_logp = torch.sum(torch.stack(step_logp_list, dim=1), dim=1)
            step_mask = torch.hstack(step_mask_list)

            extra_fetches.update({'invalid_action_mask': step_mask})

            input_dict[SampleBatch.ACTIONS] = step_actions

            extra_fetches.update({
                SampleBatch.ACTION_DIST_INPUTS: step_masked_logits,
                SampleBatch.ACTION_PROB: torch.exp(step_logp.float()),
                SampleBatch.ACTION_LOGP: step_logp,
            })

            # Update our global timestep by the batch size.
            self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

            return convert_to_non_torch_type(
                (step_actions, state_out, extra_fetches))
Example #17
    def compute_actions_from_input_dict(self,
                                        input_dict,
                                        explore=None,
                                        timestep=None,
                                        **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            # Pass lazy (torch) tensor dict to Model as `input_dict`.
            input_dict = self._lazy_tensor_dict(input_dict)
            # Pack internal state inputs into (separate) list.
            state_batches = [
                input_dict[k] for k in input_dict.keys() if "state_in" in k[:8]
            ]
            # Calculate RNN sequence lengths.
            seq_lens = np.array([1] * len(input_dict["obs"])) \
                if state_batches else None

            self._is_recurrent = state_batches is not None and state_batches != []

            # Switch to eval mode.
            self.model.eval()

            dist_inputs, state_out = self.model(input_dict, state_batches,
                                                seq_lens)

            generate_valid_action_trees = self.config['env_config'].get(
                'generate_valid_action_trees', False)
            invalid_action_masking = self.config["env_config"].get(
                "invalid_action_masking", 'none')
            allow_nop = self.config["env_config"].get("allow_nop", False)

            extra_fetches = {}

            if generate_valid_action_trees:
                infos = input_dict[SampleBatch.INFOS] \
                    if SampleBatch.INFOS in input_dict else {}

                valid_action_trees = []
                for info in infos:
                    if isinstance(info, dict) and 'valid_action_tree' in info:
                        valid_action_trees.append(info['valid_action_tree'])
                    else:
                        valid_action_trees.append({})

                exploration = TorchConditionalMaskingExploration(
                    self.model, dist_inputs, valid_action_trees, explore,
                    invalid_action_masking, allow_nop)

                actions, masked_logits, logp, mask = \
                    exploration.get_actions_and_mask()

                extra_fetches.update({'invalid_action_mask': mask})
            else:
                action_dist = self.dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                actions, logp = \
                    self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

                masked_logits = dist_inputs

            input_dict[SampleBatch.ACTIONS] = actions

            extra_fetches.update({
                SampleBatch.ACTION_DIST_INPUTS: masked_logits,
                SampleBatch.ACTION_PROB: torch.exp(logp.float()),
                SampleBatch.ACTION_LOGP: logp,
            })

            # Update our global timestep by the batch size.
            self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

            return convert_to_non_torch_type(
                (actions, state_out, extra_fetches))
Example #18
    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        explore=None,
                        timestep=None,
                        **kwargs):

        explore = explore if explore is not None else self.config["explore"]
        timestep = timestep if timestep is not None else self.global_timestep

        with torch.no_grad():
            seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
            input_dict = self._lazy_tensor_dict({
                SampleBatch.CUR_OBS: obs_batch,
                "is_training": False,
            })
            if prev_action_batch is not None:
                input_dict[SampleBatch.PREV_ACTIONS] = prev_action_batch
            if prev_reward_batch is not None:
                input_dict[SampleBatch.PREV_REWARDS] = prev_reward_batch
            state_batches = [
                self._convert_to_tensor(s) for s in (state_batches or [])
            ]

            if self.action_sampler_fn:
                action_dist = dist_inputs = None
                state_out = []
                actions, logp = self.action_sampler_fn(
                    self,
                    self.model,
                    input_dict[SampleBatch.CUR_OBS],
                    explore=explore,
                    timestep=timestep)
            else:
                # Call the exploration before_compute_actions hook.
                self.exploration.before_compute_actions(explore=explore,
                                                        timestep=timestep)
                if self.action_distribution_fn:
                    dist_inputs, dist_class, state_out = \
                        self.action_distribution_fn(
                            self,
                            self.model,
                            input_dict[SampleBatch.CUR_OBS],
                            explore=explore,
                            timestep=timestep,
                            is_training=False)
                else:
                    dist_class = self.dist_class
                    dist_inputs, state_out = self.model(
                        input_dict, state_batches, seq_lens)
                action_dist = dist_class(dist_inputs, self.model)

                # Get the exploration action from the forward results.
                actions, logp, unsquashed_actions = \
                    self.exploration.get_exploration_action(
                        action_distribution=action_dist,
                        timestep=timestep,
                        explore=explore)

            input_dict[SampleBatch.ACTIONS] = actions

            # Add default and custom fetches.
            extra_fetches = self.extra_action_out(input_dict, state_batches,
                                                  self.model, action_dist)
            # Action-logp and action-prob.
            if logp is not None:
                logp = convert_to_non_torch_type(logp)
                extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
                extra_fetches[SampleBatch.ACTION_LOGP] = logp
            unsquashed_actions = convert_to_non_torch_type(unsquashed_actions)
            extra_fetches[SampleBatch.UNSQUASHED_ACTIONS] = unsquashed_actions
            # Action-dist inputs.
            if dist_inputs is not None:
                extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs
            return convert_to_non_torch_type(
                (actions, state_out, extra_fetches))
Example #19
 def _convert_to_non_torch_type(self, data):
     if self.framework == "torch":
         return convert_to_non_torch_type(data)
     return data
Example #20
    def _compute_action_helper(self, input_dict, state_batches, seq_lens,
                               explore, timestep):
        """Shared forward pass logic (w/ and w/o trajectory view API).

        Returns:
            Tuple:
                - actions, state_out, extra_fetches, logp.
        """
        if self.action_sampler_fn:
            action_dist = dist_inputs = None
            state_out = state_batches
            actions, logp, state_out = self.action_sampler_fn(
                self,
                self.model,
                input_dict,
                state_out,
                explore=explore,
                timestep=timestep)
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(explore=explore,
                                                    timestep=timestep)
            if self.action_distribution_fn:
                dist_inputs, dist_class, state_out = \
                    self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        is_training=False)
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches,
                                                    seq_lens)

            if not (isinstance(dist_class, functools.partial)
                    or issubclass(dist_class, TorchDistributionWrapper)):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(dist_class.__name__))
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(input_dict, state_batches,
                                              self.model, action_dist)

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        # Action-logp and action-prob.
        if logp is not None:
            extra_fetches[SampleBatch.ACTION_PROB] = \
                torch.exp(logp.float())
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Update our global timestep by the batch size.
        self.global_timestep += len(input_dict[SampleBatch.CUR_OBS])

        return convert_to_non_torch_type((actions, state_out, extra_fetches))
Example #21
 def get_weights(self) -> dict:
     return {
         "module": convert_to_non_torch_type(self.module.state_dict()),
         # Optimizer state dicts don't store tensors, only ids
         "optimizers": self.optimizers.state_dict(),
     }
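
A matching set_weights (not shown above) would reverse the conversion. A minimal sketch, assuming the same self.module / self.optimizers attributes, that self.optimizers exposes a load_state_dict() counterpart, and that RLlib's convert_to_torch_tensor helper and self.device are available:

 def set_weights(self, weights: dict) -> None:
     # Move the numpy module weights back onto this policy's device as tensors.
     self.module.load_state_dict(
         convert_to_torch_tensor(weights["module"], device=self.device))
     # Optimizer state was stored as-is, so it is loaded back unchanged.
     self.optimizers.load_state_dict(weights["optimizers"])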