Exemplo n.º 1
0
    def _process_single_beam_fetches(decode_fetches, temperature=0.):
        """ Samples a single beam from the beam fetches and decodes it into per-location orders.

            The beam log probs are normalized into a distribution, adjusted with `temperature`, and one beam is
            drawn from the adjusted distribution. The log prob of the drawn beam is then divided equally between
            the decoded locations.

            :param decode_fetches: The fetches returned by self._decode_policy(). Only the first two items are
                                   read, as (beam_tokens, beam_log_probs).
            :param temperature: The temperature given to apply_temperature() before sampling the beam.
            :return: `decode_fetches` itself when it is empty (nothing to decode), otherwise a dict with the
                     location as key, and an OrderProbTokenLogProbs as value
        """
        # Nothing to decode - Handing the empty fetches back untouched
        if not decode_fetches:
            return decode_fetches

        beam_tokens, beam_log_probs = decode_fetches[:2]

        # Normalizing the beam log probs into a distribution, then applying the temperature
        normalized_probs = np.exp(beam_log_probs - logsumexp(beam_log_probs))
        sampling_probs = apply_temperature(normalized_probs, temperature=temperature).tolist()

        # Drawing the index of the beam to decode
        beam_ix = choice(range(len(normalized_probs)), p=assert_normalized(sampling_probs))

        # Decoding only the drawn beam (as a batch of size 1)
        # The log probs given to the decoder are zeros - the real values are set on the results below
        tokens_batch = np.array([beam_tokens[beam_ix]])
        placeholder_log_probs = np.zeros_like(tokens_batch)
        decoded = OrderBasedPolicyModel._decode(selected_tokens=tokens_batch,                                           # pylint: disable=protected-access
                                                log_probs=placeholder_log_probs)
        orders_by_loc = decoded['decoded_orders'][0]

        # Splitting the log prob of the beam uniformly over all locs (max() guards against a beam without locs)
        log_prob_per_loc = beam_log_probs[beam_ix] / max(1, len(orders_by_loc))

        results = {}
        for loc in orders_by_loc:
            decoded_order = orders_by_loc[loc]
            results[loc] = OrderProbTokenLogProbs(order=decoded_order.order,
                                                  probability=decoded_order.probability,
                                                  log_probs=[log_prob_per_loc])
        return results
Exemplo n.º 2
0
    def get_beam_orders(self, locs, state_proto, power_name,
                        phase_history_proto, possible_orders_proto, **kwargs):
        """ Queries the policy with beam search and returns every beam along with its probability
            Beams are kept in the order returned by the decoder (expected to be by score, highest first).

            This is a generator-based coroutine: unless fetches are already provided (or prefetch=True),
            it yields once on the processed fetches and expects their results to be sent back.

            :param locs: A list of locations for which we want orders (only the first 3 chars of each are used)
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the orders and the state values
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs. Read by this method:
                - with_state_value: Boolean that indicates to also return the state value.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
              Forwarded to self._decode_policy() (any 'use_beam' key is stripped, beam search is always used):
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False and with_state_value=False (default), a tuple consisting of:
                     1) A list of beams (i.e. a list of selected orders for each beam)
                     2) A list of probability (the probability of selecting each beam)
                - if prefetch=False and with_state_value=True, a tuple consisting of:
                     1) A list of beams (i.e. a list of selected orders for each beam)
                     2) A list of probability (the probability of selecting each beam)
                     3) The state value for the given state
                If the decode fetches are None, the beams and probabilities are empty and the state value is 0.
        """
        fetch_prefix = 'get_orders'
        decode_key = '%s/decode_fetches' % fetch_prefix

        with_state_value = kwargs.get('with_state_value', False)
        is_prefetching = kwargs.get('prefetch', False)
        fetches = kwargs.get('fetches', {})

        # The model is only queried if no fetches were provided, or if we are explicitly prefetching
        must_query_model = is_prefetching or not fetches

        if must_query_model:
            short_locs = [loc[:3] for loc in locs]
            policy_kwargs = strip_keys(kwargs, ['use_beam'])

            # Running policy model (always with beam search)
            fetches[decode_key] = self._decode_policy(short_locs,
                                                      state_proto,
                                                      power_name,
                                                      phase_history_proto,
                                                      possible_orders_proto,
                                                      use_beam=True,
                                                      **policy_kwargs)

            # Prefetching - The caller will yield on the fetches itself
            if is_prefetching:
                return fetches

            # Not prefetching - Waiting for the results of the fetches
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # (beam_orders, beam_log_probs, draw, value)
        decode_fetches = fetches[decode_key]

        # Nothing was decoded - Returning empty beams, empty probs (and a state value of 0.)
        if decode_fetches is None:
            return ([], [], 0.) if with_state_value else ([], [])

        # Normalizing the log probs of the beams
        # A temperature of 1. is always used here, independently of the 'temperature' kwarg
        beam_orders, beam_log_probs = decode_fetches[:2]
        normalized_probs = np.exp(beam_log_probs - logsumexp(beam_log_probs))
        beam_probs = apply_temperature(normalized_probs, temperature=1.).tolist()

        # Converting each beam from a flat list of tokens to a list of order strings
        beams = []
        for candidate_tokens in beam_orders:
            candidate_orders = []

            # Only complete groups of TOKENS_PER_ORDER tokens are decoded
            for order_ix in range(len(candidate_tokens) // TOKENS_PER_ORDER):
                start = order_ix * TOKENS_PER_ORDER
                stop = start + TOKENS_PER_ORDER

                # Tokens with an id <= EOS_ID are skipped (presumably padding / end-of-sentence markers)
                words = [ix_to_token(token)
                         for token in candidate_tokens[start:stop]
                         if token > EOS_ID]
                order = ' '.join(words)

                # An order without any remaining token is dropped
                if order:
                    candidate_orders.append(order)
            beams.append(candidate_orders)

        # The state value is the last item of the decode fetches
        if with_state_value:
            return beams, beam_probs, decode_fetches[-1]
        return beams, beam_probs