Example No. 1
    def step(self, time_step, state):
        """Returns the action to be taken.

      Args:
        time_step: an instance of rl_environment.TimeStep.
        state: should be able to recover this from time_step but dont know how.. therefore we just add this argument

      Returns:
        A `rl_agent.StepOutput` containing the action probs and actions.
    """
        # state = time_step.observations["info_state"][self._player_id]
        legal_actions = time_step.observations["legal_actions"][
            self._player_id]

        # Prevent undefined errors if this agent never plays until terminal step
        action, probs = None, None

        # Act step: don't act at terminal states.
        if not time_step.last():
            actions, probs = self._matrix_game(state)
            # Take the absolute value of small negative outputs so the
            # probabilities are non-negative, then re-normalize.
            probs = np.abs(np.array(probs).flatten())
            probs = probs / probs.sum()
            action = np.random.choice(actions, p=probs)

        return rl_agent.StepOutput(action=action, probs=probs)
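A note on the clean-up above: the matrix-game policy can emit values with tiny negative components or a sum slightly off 1, so the probabilities are made non-negative and re-normalized before sampling. A standalone sketch of that step with made-up numbers:

import numpy as np

raw_probs = [[0.301, -0.0005, 0.700]]   # e.g. noisy policy output
probs = np.abs(np.array(raw_probs).flatten())
probs = probs / probs.sum()             # now a valid distribution

action = np.random.choice(len(probs), p=probs)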
Example No. 2
    def step(self, time_step, is_evaluation=False):
        """Returns the action to be taken and updates the Q-values if needed.

            Args:
            time_step: an instance of rl_environment.TimeStep.
            is_evaluation: bool, whether this is a training or evaluation call.

            Returns:
            A `rl_agent.StepOutput` containing the action probs and chosen action.
        """
        action, info_state, legal_actions, probs = self._step_action_selection(is_evaluation, time_step)

        # Learn step: don't learn during evaluation or at first agent steps.
        if self._prev_info_state and not is_evaluation:
            # Anneal exploration.
            self._exploration = max(
                self._exploration * self._exploration_annealing,
                self._exploration_min)
            # target = reward + discount * max Q-value over legal actions.
            # (In a one-shot game the future-reward term is never used.)
            target = time_step.rewards[self._player_id]
            if not time_step.last():  # Q values are zero for terminal.
                target += self._discount_factor * max(
                    [self._q_values[info_state][a] for a in legal_actions])

            prev_q_value = self._q_values[self._prev_info_state][self._prev_action]
            # last_loss_value = target - Q(t)[prev_action]
            self._last_loss_value = target - prev_q_value
            self._train_update()

            if time_step.last():  # prepare for the next episode.
                self._prev_info_state = None
                return

        # Don't mess up with the state during evaluation.
        if not is_evaluation:
            self._prev_info_state = info_state
            self._prev_action = action
        return rl_agent.StepOutput(action=action, probs=probs)
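As the comments above note, the temporal-difference target is the immediate reward plus the discounted maximum Q-value over the next state's legal actions (and just the reward at a terminal step). A tiny tabular sketch with invented numbers:

import collections

discount_factor = 0.9
q_values = collections.defaultdict(lambda: collections.defaultdict(float))
q_values["next_info_state"].update({0: 0.2, 1: 0.5, 2: -0.1})

reward = 1.0
legal_actions = [0, 1, 2]
target = reward + discount_factor * max(
    q_values["next_info_state"][a] for a in legal_actions)
print(target)   # 1.0 + 0.9 * 0.5 = 1.45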
Example No. 3
    def step(self, time_step, is_evaluation=False):
        # If it is the end of the episode, don't select an action.
        if time_step.last():
            return

        # Play the preferred action if it is legal; otherwise pick a random
        # legal action.
        cur_legal_actions = time_step.observations["legal_actions"][
            self._player_id]
        if self._preferred_action in cur_legal_actions:
            probs = np.zeros(self._num_actions)
            probs[self._preferred_action] = 1
            return rl_agent.StepOutput(action=self._preferred_action,
                                       probs=probs)
        else:
            action = np.random.choice(cur_legal_actions)
            probs = np.zeros(self._num_actions)
            probs[cur_legal_actions] = 1.0 / len(cur_legal_actions)

        return rl_agent.StepOutput(action=action, probs=probs)
Example No. 4
    def step(self, time_step, is_evaluation=False):
        """Returns the action to be taken and updates the Q-networks if needed.

    Args:
      time_step: an instance of rl_environment.TimeStep.
      is_evaluation: bool, whether this is a training or evaluation call.

    Returns:
      A `rl_agent.StepOutput` containing the action probs and chosen action.
    """
        if self._mode == MODE.best_response:
            agent_output = self._rl_agent.step(time_step, is_evaluation)
            if not is_evaluation and not time_step.last():
                self._add_transition(time_step, agent_output)

        elif self._mode == MODE.average_policy:
            # Act step: don't act at terminal info states.
            if not time_step.last():
                info_state = time_step.observations["info_state"][
                    self.player_id]
                legal_actions = time_step.observations["legal_actions"][
                    self.player_id]
                try:
                    action, probs = self._act(info_state, legal_actions)
                except Exception:
                    # Log the offending info state, then re-raise rather than
                    # continuing with undefined action/probs.
                    print(info_state)
                    raise
                agent_output = rl_agent.StepOutput(action=action, probs=probs)

            if self._prev_timestep and not is_evaluation:
                self._rl_agent.add_transition(self._prev_timestep,
                                              self._prev_action, time_step)
        else:
            raise ValueError("Invalid mode ({})".format(self._mode))

        if not is_evaluation:
            self._step_counter += 1

            if self._step_counter % self._learn_every == 0:
                self._last_sl_loss_value = self._learn()
                # If learn step not triggered by rl policy, learn.
                if self._mode == MODE.average_policy:
                    self._rl_agent.learn()

            # Prepare for the next episode.
            if time_step.last():
                self._sample_episode_policy()
                self._prev_timestep = None
                self._prev_action = None
                return
            else:
                self._prev_timestep = time_step
                self._prev_action = agent_output.action

        return agent_output
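For context, the two modes above are the NFSP ones: at the start of each episode the agent resamples its mode (see `_sample_episode_policy` above), choosing between its best-response (RL) policy and its average (supervised) policy. A minimal sketch of that sampling, assuming an anticipatory parameter; the names used here are illustrative, not necessarily the attributes of this class:

import enum
import numpy as np

class MODE(enum.Enum):
    best_response = 0
    average_policy = 1

def sample_episode_policy(anticipatory_param=0.1):
    # With probability anticipatory_param play the best response,
    # otherwise follow the average policy.
    if np.random.rand() < anticipatory_param:
        return MODE.best_response
    return MODE.average_policy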
Example No. 5
    def step(self, time_step, is_evaluation=False, add_transition_record=True):
        """Returns the action to be taken and updates the Q-network if needed.

    Args:
      time_step: an instance of rl_environment.TimeStep.
      is_evaluation: bool, whether this is a training or evaluation call.
      add_transition_record: Whether to add to the replay buffer on this step.

    Returns:
      A `rl_agent.StepOutput` containing the action probs and chosen action.
    """

        # Act step: don't act at terminal info states or if it's not our turn.
        if (not time_step.last()) and (time_step.is_simultaneous_move()
                                       or self.player_id
                                       == time_step.current_player()):
            info_state = time_step.observations["info_state"][self.player_id]
            legal_actions = time_step.observations["legal_actions"][
                self.player_id]
            epsilon = self._get_epsilon(is_evaluation)
            action, probs = self._epsilon_greedy(info_state, legal_actions,
                                                 epsilon)
        else:
            action = None
            probs = []

        # Don't mess up with the state during evaluation.
        if not is_evaluation:
            self._step_counter += 1

            if self._step_counter % self._learn_every == 0:
                self._last_loss_value = self.learn()

            if self._step_counter % self._update_target_network_every == 0:
                # Sync the target network by copying the current Q-network
                # parameters.
                self.params_target_q_network = jax.tree_multimap(
                    lambda x: x.copy(), self.params_q_network)

            if self._prev_timestep and add_transition_record:
                # We may omit record adding here if it's done elsewhere.
                self.add_transition(self._prev_timestep, self._prev_action,
                                    time_step)

            if time_step.last():  # prepare for the next episode.
                self._prev_timestep = None
                self._prev_action = None
                return
            else:
                self._prev_timestep = time_step
                self._prev_action = action

        return rl_agent.StepOutput(action=action, probs=probs)
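One caveat on the target-network copy above: `jax.tree_multimap` has been deprecated in recent JAX releases in favour of `jax.tree_util.tree_map`, which behaves the same way for this single-tree use. A minimal sketch of the equivalent copy, assuming the parameters are an arbitrary pytree of arrays:

import jax
import numpy as np

# Hypothetical pytree of Q-network parameters.
params_q_network = {"dense": {"w": np.ones((4, 2)), "b": np.zeros(2)}}

# Equivalent target-network sync with the non-deprecated API.
params_target_q_network = jax.tree_util.tree_map(
    lambda x: x.copy(), params_q_network)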
Example No. 6
    def step(self, time_step, is_evaluation=False):
        """Returns the action to be taken and updates the Q-values if needed.

    Args:
      time_step: an instance of rl_environment.TimeStep.
      is_evaluation: bool, whether this is a training or evaluation call.

    Returns:
      A `rl_agent.StepOutput` containing the action probs and chosen action.
    """
        if self._centralized:
            info_state = str(time_step.observations["info_state"])
        else:
            info_state = str(
                time_step.observations["info_state"][self._player_id])
        legal_actions = time_step.observations["legal_actions"][
            self._player_id]

        # Prevent undefined errors if this agent never plays until terminal step
        action, probs = None, None

        # Act step: don't act at terminal states.
        if not time_step.last():
            epsilon = 0.0 if is_evaluation else self._epsilon
            action, probs = self._epsilon_greedy(info_state,
                                                 legal_actions,
                                                 epsilon=epsilon)

        # Learn step: don't learn during evaluation or at first agent steps.
        if self._prev_info_state and not is_evaluation:
            target = time_step.rewards[self._player_id]
            if not time_step.last():  # Q values are zero for terminal.
                target += self._discount_factor * max(
                    [self._q_values[info_state][a] for a in legal_actions])

            prev_q_value = self._q_values[self._prev_info_state][
                self._prev_action]
            self._last_loss_value = target - prev_q_value
            self._q_values[self._prev_info_state][self._prev_action] += (
                self._step_size * self._last_loss_value)

            # Decay epsilon, if necessary.
            self._epsilon = self._epsilon_schedule.step()

            if time_step.last():  # prepare for the next episode.
                self._prev_info_state = None
                return

        # Don't mess up with the state during evaluation.
        if not is_evaluation:
            self._prev_info_state = info_state
            self._prev_action = action
        return rl_agent.StepOutput(action=action, probs=probs)
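For reference, a step method with this signature is usually driven by an episode loop; the sketch below shows one plausible way to do it with OpenSpiel's rl_environment and tabular Q-learners (the game choice and agent construction here are assumptions for illustration). The extra step call after the episode ends lets each agent learn from the terminal rewards.

from open_spiel.python import rl_environment
from open_spiel.python.algorithms import tabular_qlearner

env = rl_environment.Environment("kuhn_poker")
num_actions = env.action_spec()["num_actions"]
agents = [
    tabular_qlearner.QLearner(player_id=idx, num_actions=num_actions)
    for idx in range(env.num_players)
]

for _ in range(1000):
    time_step = env.reset()
    while not time_step.last():
        player_id = time_step.observations["current_player"]
        agent_output = agents[player_id].step(time_step)
        time_step = env.step([agent_output.action])
    # Episode over: let every agent observe the terminal rewards.
    for agent in agents:
        agent.step(time_step)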
Example No. 7
    def step(self, time_step, is_evaluation=False):
        # If it is the end of the episode, don't select an action.
        if time_step.last():
            return

        # Pick the minimal legal action.
        cur_legal_actions = time_step.observations["legal_actions"][
            self._player_id]
        action = np.min(cur_legal_actions)
        probs = np.zeros(self._num_actions)
        probs[action] = 1.0

        return rl_agent.StepOutput(action=action, probs=probs)
Example No. 8
    def step(self, time_step, is_evaluation=False):
        """Returns the action to be taken and updates the network if needed.

    Args:
      time_step: an instance of rl_environment.TimeStep.
      is_evaluation: bool, whether this is a training or evaluation call.

    Returns:
      A `rl_agent.StepOutput` containing the action probs and chosen action.
    """
        # Act step: don't act at terminal info states or if it's not our turn.
        if (not time_step.last()) and (time_step.is_simultaneous_move()
                                       or self.player_id
                                       == time_step.current_player()):
            info_state = time_step.observations["info_state"][self.player_id]
            legal_actions = time_step.observations["legal_actions"][
                self.player_id]
            action, probs = self._act(info_state, legal_actions)
        else:
            action = None
            probs = []

        if not is_evaluation:
            self._step_counter += 1

            # Add data points to current episode buffer.
            if self._prev_time_step:
                self._add_transition(time_step)

            # Episode done, add to dataset and maybe learn.
            if time_step.last():
                self._add_episode_data_to_dataset()
                self._episode_counter += 1

                if len(self._dataset["returns"]) >= self._batch_size:
                    self._critic_update()
                    self._num_learn_steps += 1
                    if self._num_learn_steps % self._num_critic_before_pi == 0:
                        self._pi_update()
                    self._dataset = collections.defaultdict(list)

                self._prev_time_step = None
                self._prev_action = None
                return
            else:
                self._prev_time_step = time_step
                self._prev_action = action

        return rl_agent.StepOutput(action=action, probs=probs)
Example No. 9
    def step(self, time_step, is_evaluation=False, noise=None):
        if (not time_step.last()) and (time_step.is_simultaneous_move()
                                       or self.player_id
                                       == time_step.current_player()):
            # info_state has shape (dim,).
            info_state = time_step.observations["info_state"][self.player_id]
            legal_actions = time_step.observations["legal_actions"][
                self.player_id]
            action, probs = self._act(info_state, legal_actions, is_evaluation,
                                      noise)
        else:
            action = None
            probs = []

        return rl_agent.StepOutput(action=action, probs=probs)
Example No. 10
    def step(self, time_step, **kwargs):
        player_id = time_step.observations["current_player"]
        legal_actions = time_step.observations["legal_actions"][player_id]

        num_legal_actions = len(legal_actions)
        num_actions = self.env.action_spec()["num_actions"]

        probs = np.zeros(num_actions)
        if num_legal_actions != 0:
            probs[legal_actions] = 1 / num_legal_actions
        else:
            raise ValueError("The number of legal actions is zero.")

        action = np.random.choice(num_actions, p=probs)

        return rl_agent.StepOutput(action=action, probs=probs)
Example No. 11
    def step(self, time_step, is_evaluation=False, add_transition_record=True):
        """Returns the action to be taken and updates the Q-network if needed.

    Args:
      time_step: an instance of rl_environment.TimeStep.
      is_evaluation: bool, whether this is a training or evaluation call.
      add_transition_record: Whether to add to the replay buffer on this step.

    Returns:
      A `rl_agent.StepOutput` containing the action probs and chosen action.
    """
        # Act step: don't act at terminal info states.
        if not time_step.last():
            info_state = time_step.observations["info_state"][self.player_id]
            legal_actions = time_step.observations["legal_actions"][
                self.player_id]
            epsilon = self._get_epsilon(is_evaluation)
            action, probs = self._epsilon_greedy(info_state, legal_actions,
                                                 epsilon)
        else:
            # Avoid referencing undefined action/probs at terminal evaluation
            # steps.
            action = None
            probs = []

        # Don't mess up with the state during evaluation.
        if not is_evaluation:
            self._step_counter += 1

            if self._step_counter % self._learn_every == 0:
                self._last_loss_value = self.learn()

            if self._step_counter % self._update_target_network_every == 0:
                self._session.run(self._update_target_network)

            if self._prev_timestep and add_transition_record:
                # We may omit record adding here if it's done elsewhere.
                self.add_transition(self._prev_timestep, self._prev_action,
                                    time_step)

            if time_step.last():  # prepare for the next episode.
                self._prev_timestep = None
                self._prev_action = None
                return
            else:
                self._prev_timestep = time_step
                self._prev_action = action

        return rl_agent.StepOutput(action=action, probs=probs)
Example No. 12
    def step(self, time_step, is_evaluation=False):
        if (not time_step.last()) and (time_step.is_simultaneous_move()
                                       or self.player_id
                                       == time_step.current_player()):
            # info_state has shape (dim,).
            info_state = time_step.observations["info_state"][self.player_id]
            legal_actions = time_step.observations["legal_actions"][
                self.player_id]
            action, probs = self._act(info_state, legal_actions, is_evaluation)

        else:
            action = None
            probs = []

        if not is_evaluation:
            # Add data points to current episode buffer.
            if self._prev_time_step:
                self._add_transition(time_step)

            # Episode done, add to dataset and maybe learn.
            if time_step.last():
                self._add_episode_data_to_dataset()

                direction = self._current_policy_idx // self._nb_directions
                delta_idx = self._current_policy_idx % self._nb_directions
                if direction == 0:
                    self._pos_rew[delta_idx] = self._dataset["returns"]
                    self._dataset = collections.defaultdict(list)
                elif direction == 1:
                    self._neg_rew[delta_idx] = self._dataset["returns"]
                    self._dataset = collections.defaultdict(list)
                else:
                    raise ValueError(
                        "Number of directions tried beyond scope.")

                self.deltas_iterator()
                self._prev_time_step = None
                self._prev_action = None
                return
            else:
                self._prev_time_step = time_step
                self._prev_action = action

        return rl_agent.StepOutput(action=action, probs=probs)
Example No. 13
    def step(self, time_step, is_evaluation=False):
        action, info_state, legal_actions, probs = self._step_action_selection(
            is_evaluation, time_step)

        # Learn step: don't learn during evaluation or at first agent step.
        if self._prev_info_state and not is_evaluation:
            self._k_rewards.append(time_step.rewards[self._player_id])
            self._k_actions.append(self._prev_action)
            self._k_probs.append(self._prev_probs)
            if len(self._k_actions) == self._k:
                # Anneal exploration.
                self._exploration = max(
                    self._exploration * self._exploration_annealing,
                    self._exploration_min)
                max_rew = np.amax(self._k_rewards)
                best_index = np.where(self._k_rewards == max_rew)[0][0]
                self._prev_action = self._k_actions[best_index]
                self._prev_probs = self._k_probs[best_index]
                target = self._k_rewards[best_index]
                if not time_step.last():  # no legal actions in last timestep
                    target += self._discount_factor * max(
                        [self._q_values[info_state][a] for a in legal_actions])

                prev_q_value = self._q_values[self._prev_info_state][
                    self._prev_action]
                self._last_loss_value = target - prev_q_value  # target - Q(t)[prev_action]
                self._train_update()

                #reset all k_arrays
                self._k_actions = []
                self._k_rewards = []
                self._k_probs = []

            if time_step.last():  # prepare for the next episode.
                self._prev_info_state = None
                return

        # Don't mess up with the state during evaluation.
        if not is_evaluation:
            self._prev_info_state = info_state
            self._prev_action = action
            self._prev_probs = probs
        return rl_agent.StepOutput(action=action, probs=probs)
Example No. 14
    def step(self, time_step, is_evaluation=False):

        info_state = str(time_step.observations["info_state"][self._player_id])

        action, probs, reward = None, None, None
        if not time_step.last():
            # Clean up numerical noise so the policy is a valid probability
            # distribution.
            policy = np.array(self._policy)
            policy[policy <= 0] = 0
            policy /= sum(policy)
            action = np.random.choice(range(self._num_actions), p=policy)
            probs = self._policy

        if self._prev_info_state and not is_evaluation:
            probs = self._policy
            reward = time_step.rewards[self._player_id]

            action = np.random.choice(range(self._num_actions), p=probs)
            for a in range(self._num_actions):
                if a == self._prev_action:
                    self._policy[a] = probs[a] + self._learning_rate * (
                        reward - probs[a] * reward)
                else:
                    self._policy[a] = probs[a] + self._learning_rate * (
                        -reward * probs[a])
            if time_step.last():  # prepare for the next episode.
                self._prev_info_state = None
                return

        # Don't mess up with the state during evaluation.
        if not is_evaluation:
            self._prev_info_state = info_state
            self._prev_action = action

        return rl_agent.StepOutput(action=action, probs=probs)
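The policy update above is a linear reward-inaction style rule: the chosen action's probability moves toward 1 in proportion to the reward while all other probabilities shrink proportionally, so the vector stays normalized. A small numeric sketch with invented values:

import numpy as np

# Toy numbers (not taken from the agent above): 3 actions, learning rate 0.1.
learning_rate = 0.1
policy = np.array([0.5, 0.3, 0.2])
prev_action, reward = 0, 1.0

updated = policy.copy()
for a in range(len(policy)):
    if a == prev_action:
        updated[a] = policy[a] + learning_rate * (reward - policy[a] * reward)
    else:
        updated[a] = policy[a] + learning_rate * (-reward * policy[a])

print(updated)        # approximately [0.55, 0.27, 0.18]
print(updated.sum())  # still sums to 1 (up to float rounding)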
Example No. 15
    def step(self, time_step, is_evaluation=False):
        """Returns the action to be taken and updates the value functions.

    Args:
      time_step: an instance of rl_environment.TimeStep.
      is_evaluation: bool, whether this is a training or evaluation call.

    Returns:
      A `rl_agent.StepOutput` containing the action probs and chosen action.
    """
        # Act step: don't act at terminal info states.
        if not time_step.last():
            info_state = time_step.observations["info_state"][self.player_id]
            legal_actions = time_step.observations["legal_actions"][
                self.player_id]
            epsilon = self._get_epsilon(self._agent.step_counter,
                                        is_evaluation)

            # Sample an action from EVA via epsilon greedy policy.
            action, probs = self._epsilon_greedy(
                self._q_eva[tuple(info_state)], legal_actions, epsilon)

        # Update Step: Only with transitions and not when evaluating.
        if (not is_evaluation and self._last_time_step is not None):
            info_state = self._last_time_step.observations["info_state"][
                self.player_id]
            legal_actions = self._last_time_step.observations["legal_actions"][
                self.player_id]
            epsilon = self._get_epsilon(self._agent.step_counter,
                                        is_evaluation)

            # Get embedding.
            self._info_state = torch.Tensor(np.expand_dims(info_state, axis=0))
            infostate_embedding = self._embedding_network(
                self._info_state).detach()[0]

            neighbours_value = self._value_buffer.knn(infostate_embedding,
                                                      MEM_KEY_NAME,
                                                      self._num_neighbours, 1)
            # Collect the values of the k nearest neighbours from the value
            # buffer, i.e. Q_np(s_k).
            neighbours_replay = self._replay_buffer.knn(
                infostate_embedding, MEM_KEY_NAME, self._num_neighbours,
                self._trajectory_len)

            # Take a step with the parametric model and get q-values. Use the
            # embedding as input to the parametric model.
            # TODO(author6) Recompute embeddings for buffers on learning steps.
            if self._embedding_as_parametric_input:
                last_time_step_copy = copy.deepcopy(self._last_time_step)
                last_time_step_copy.observations["info_state"][
                    self.player_id] = infostate_embedding
                self._agent.step(last_time_step_copy,
                                 add_transition_record=False)
            else:
                self._agent.step(self._last_time_step,
                                 add_transition_record=False)
            q_values = self._agent._q_network(self._info_state).detach()[0]
            # Update EVA: Q_eva = lambda q_theta(s_t) + (1-lambda) sum(Q_np(s_k, .))/K
            for a in legal_actions:
                q_theta = q_values[a]
                self._q_eva[tuple(info_state)][a] = (
                    self._lambda * q_theta + (1 - self._lambda) *
                    sum([elem[1].value
                         for elem in neighbours_value]) / self._num_neighbours)

            # Append (e,s,a,r,s') to Replay Buffer
            self._add_transition_replay(infostate_embedding, time_step)

            # update Q_np with Traces using TCP
            self._trajectory_centric_planning(neighbours_replay)

            # Append Q_np(s, a) to Value Buffer
            self._add_transition_value(
                infostate_embedding,
                self._q_np[tuple(info_state)][self._last_action])

        # Prepare for the next episode.
        if time_step.last():
            self._last_time_step = None
            self._last_action = None
            return

        self._last_time_step = time_step
        self._last_action = action
        return rl_agent.StepOutput(action=action, probs=probs)
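The Q_eva update above mixes the parametric estimate with the average of the K nearest non-parametric values, Q_eva(s, a) = lambda * q_theta(s, a) + (1 - lambda) * sum_k Q_np(s_k, a) / K. A small numeric sketch with invented values:

import numpy as np

lam = 0.5                                     # mixing weight lambda
q_theta = 1.2                                 # parametric Q-value for one action
neighbour_values = np.array([0.8, 1.0, 0.6])  # non-parametric Q_np(s_k, a)

q_eva = lam * q_theta + (1 - lam) * neighbour_values.mean()
print(q_eva)   # 0.5 * 1.2 + 0.5 * 0.8 = 1.0 (up to float rounding)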
Example No. 16
    def step(self, time_step, is_evaluation=False):
        # If it is the end of the episode, don't select an action.
        if time_step.last():
            return

        # Play the minimal legal action, or a higher one when needed to win
        # the trick.
        cur_legal_actions = time_step.observations["legal_actions"][
            self._player_id]
        minimal_action = np.min(cur_legal_actions)
        info_state = time_step.observations["info_state"][self._player_id]
        # if lead
        if info_state[1] == 1.0:
            action = minimal_action
            probs = np.zeros(self._num_actions)
            probs[action] = 1.0
            return rl_agent.StepOutput(action=action, probs=probs)
        # 337:337+52*4 - indices for current_trick observation
        current_trick = info_state[337:337 + 52 * 4]
        # Order of players in current_trick: Us, LH, Pd, RH.
        # Careful: for the dummy everything is permuted (Us <-> Pd, LH <-> RH).
        # DECLARER
        # us = 0 everywhere
        # lh = 0, pd = 0, rh = 0 -> we are first 0 0 0 0
        # lh = 0, pd = 0, rh = 1 -> we are second 0 0 0 1
        # lh = 0, pd = 1, rh = 1 -> we are third 0 0 1 1
        # lh = 1, pd = 1, rh = 1 -> we are fourth 0 1 1 1
        # DUMMY
        # pd = 0 everywhere
        # lh = 0, us = 0, rh = 0 -> we are first 0 0 0 0
        # lh = 1, us = 0, rh = 0 -> we are second 0 1 0 0
        # lh = 1, us = 1, rh = 0 -> we are third 1 1 0 0
        # lh = 1, us = 1, rh = 1 -> we are fourth 1 1 0 1
        current_trick_us = current_trick[:52]
        current_trick_lh = current_trick[52:104]
        current_trick_pd = current_trick[104:156]
        current_trick_rh = current_trick[156:208]
        played_us = bool(np.sum(current_trick_us))
        played_lh = bool(np.sum(current_trick_lh))
        played_pd = bool(np.sum(current_trick_pd))
        played_rh = bool(np.sum(current_trick_rh))
        bools = [played_us, played_lh, played_pd, played_rh]
        first = [0, 0, 0, 0]
        declarer = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1]]
        dummy = [[0, 1, 0, 0], [1, 1, 0, 0], [1, 1, 0, 1]]
        if bools == first:
            action = minimal_action
        for i in range(3):
            if bools == dummy[i]:
                current_trick_us, current_trick_lh, current_trick_pd, current_trick_rh = \
                permute_players(current_trick_us, current_trick_lh, current_trick_pd, current_trick_rh)
                played_us = bool(np.sum(current_trick_us))
                played_lh = bool(np.sum(current_trick_lh))
                played_pd = bool(np.sum(current_trick_pd))
                played_rh = bool(np.sum(current_trick_rh))
                bools = [played_us, played_lh, played_pd, played_rh]
            if bools == declarer[i]:
                # second seat
                if i == 0:
                    card_rh = np.argmax(current_trick_rh)
                    action = beat_or_play_lowest(cur_legal_actions, card_rh)
                # third seat
                if i == 1:
                    card_pd = np.argmax(current_trick_pd)
                    suit_pd = card_pd % 4
                    card_rh = np.argmax(current_trick_rh)
                    suit_rh = card_rh % 4
                    if suit_pd != suit_rh:
                        # opponent didn't follow the suit so partner's card is the highest
                        action = minimal_action
                    else:
                        if card_pd > card_rh:
                            # partner's card is the highest
                            action = minimal_action
                        else:
                            # check if I can beat and beat if I can
                            action = beat_or_play_lowest(
                                cur_legal_actions, card_rh)
                # fourth seat
                if i == 2:
                    card_lh = np.argmax(current_trick_lh)
                    suit_lh = card_lh % 4
                    card_pd = np.argmax(current_trick_pd)
                    suit_pd = card_pd % 4
                    card_rh = np.argmax(current_trick_rh)
                    suit_rh = card_rh % 4
                    if suit_pd == suit_lh == suit_rh:
                        if card_pd > max(card_lh, card_rh):
                            action = minimal_action
                        else:
                            action = beat_or_play_lowest(
                                cur_legal_actions, max(card_lh, card_rh))
                    elif suit_pd == suit_lh:
                        if card_pd > card_lh:
                            action = minimal_action
                        else:
                            action = beat_or_play_lowest(
                                cur_legal_actions, card_lh)
                    elif suit_lh == suit_rh:
                        action = beat_or_play_lowest(cur_legal_actions,
                                                     max(card_lh, card_rh))
                    else:
                        action = beat_or_play_lowest(cur_legal_actions,
                                                     card_lh)

        probs = np.zeros(self._num_actions)
        probs[action] = 1.0

        return rl_agent.StepOutput(action=action, probs=probs)