Пример #1
0
    def _process_single_beam_fetches(decode_fetches, temperature=0.):
        """ Samples one beam out of the beam fetches and decodes it into per-location orders.
            :param decode_fetches: The fetches returned by self._decode_policy(). Only the first two items are
                                   used: the tokens of each beam and the log prob of each beam.
            :param temperature: The temperature applied to the beam probabilities before sampling a beam.
            :return: A dict with the location as key, and an OrderProbTokenLogProbs as value.
                     The (falsy) input is returned unchanged when there is nothing to decode.
        """
        # Nothing was fetched - There is nothing to decode
        if not decode_fetches:
            return decode_fetches

        beam_tokens, beam_log_probs = decode_fetches[:2]

        # Normalizing the beam log probs into probabilities, then applying the temperature
        beam_probs = np.exp(beam_log_probs - logsumexp(beam_log_probs))
        tempered_probs = apply_temperature(beam_probs, temperature=temperature).tolist()

        # Drawing the index of the beam to decode
        beam_ix = choice(range(len(beam_probs)), p=assert_normalized(tempered_probs))

        # Decoding only the chosen beam (as a batch of size 1)
        # The token log probs are placeholders (zeros), they are replaced below
        chosen_tokens = np.array([beam_tokens[beam_ix]])
        placeholder_log_probs = np.zeros_like(chosen_tokens)
        beam_orders = OrderBasedPolicyModel._decode(selected_tokens=chosen_tokens,                                     # pylint: disable=protected-access
                                                    log_probs=placeholder_log_probs)['decoded_orders'][0]

        # Splitting the log prob of the beam equally between all decoded locs
        shared_log_prob = beam_log_probs[beam_ix] / max(1, len(beam_orders))
        uniform_results = {}
        for loc in beam_orders:
            decoded = beam_orders[loc]
            uniform_results[loc] = OrderProbTokenLogProbs(order=decoded.order,
                                                          probability=decoded.probability,
                                                          log_probs=[shared_log_prob])
        return uniform_results
Пример #2
0
    def expand(self, confirmed_orders, locs, state_proto, power_name,
               phase_history_proto, possible_orders_proto, **kwargs):
        """ Computes the conditional probability of possible orders for each loc given the confirmed orders.

            This is a generator: unless it is prefetching or post-fetching, it yields the pending fetches once and
            expects the resolved fetches to be sent back before computing the probabilities.

            :param confirmed_orders: The list of orders on which to condition the probs (e.g. ['A PAR H', 'A MAR - SPA']
            :param locs: The locations for which we want to compute probabilities
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the probabilities
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False,
                    A dictionary with every location in locs as key, and a list of tuples where each tuple is composed
                     of 1) an order, 2) the order conditional probability, 3) the conditional log probs of each token
                        e.g. {'PAR': [('A PAR H', 0.000, [...]), ('A PAR - BUR', 0.000, [...]), ...]}
                    When no prefix needs to be expanded, every order gets a probability of 1. and log probs of 0.
        """
        # pylint: disable=too-many-nested-blocks
        # Determining if we need to prefetch or postfetch
        fetches = kwargs.get('fetches', {})
        is_prefetching = kwargs.get('prefetch', False)
        is_postfetching = fetches and not is_prefetching
        fetch_prefix = 'expand'

        # Locations
        locs = [loc[:3] for loc in locs]
        confirmed_locs = [order.split()[1][:3] for order in confirmed_orders]

        # Tokenizing every possible order only once (the tokens are needed by steps 1, 2 and 6)
        # Each token list starts with a negative marker that depends on the first position of the loc in locs
        # e.g. [-2, 25, 4, 25, ...]
        tokenized_orders = {}  # {loc => [(order, tokens), ...]}
        for loc_ix, loc in enumerate(locs):
            if loc in tokenized_orders:
                continue
            loc_marker = -1 - loc_ix
            tokenized_orders[loc] = [(order, [loc_marker] + self.tokenize(order))
                                     for order in possible_orders_proto[loc].value]

        # Getting fetches
        if not is_postfetching:

            confirmed_tokens = []
            for order in confirmed_orders:
                confirmed_tokens += self.tokenize(order)

            # Building a list of conditional probs we want to expand
            # e.g. A PAR H, A PAR - MAR, A PAR - BUR
            # we want P('H' | 'A PAR'), P('MAR' | 'A PAR -'), ...

            # 1) We compute all the prefix (RHS of the prob) and we only expand where count is > 1
            prefix_count = {}  # {prefix => count}
            for loc in locs:
                for _, tokens in tokenized_orders[loc]:
                    for token_ix in range(1, len(tokens)):
                        prefix = tuple(tokens[:token_ix])
                        prefix_count[prefix] = prefix_count.get(prefix, 0) + 1

            # 2) Building the list of feedable items only for probs we need to expand
            feedable_items = OrderedDict()   # {prefix => feedable_item}
            items_to_expand = OrderedDict()  # {prefix => set of next available token}
            for loc in locs:
                loc_tokens = [tokens for _, tokens in tokenized_orders[loc]]

                for tokens in loc_tokens:
                    for token_ix in range(1, len(tokens)):
                        prefix = tuple(tokens[:token_ix])

                        # Prefix is unambiguous, or it has already been processed
                        if prefix_count[prefix] <= 1 or prefix in feedable_items:
                            continue

                        feedable_item = self.feedable_dataset.get_feedable_item(confirmed_locs + [loc],
                                                                                state_proto,
                                                                                power_name,
                                                                                phase_history_proto,
                                                                                possible_orders_proto,
                                                                                **kwargs)

                        # The teacher forcing is GO_ID, the confirmed orders prefix, the actual prefix, and a dummy
                        # (prefix[0] is the loc marker, it is not a real token)
                        order_prefix = list(prefix[1:])
                        feedable_item['decoder_inputs'] = [GO_ID] + confirmed_tokens + order_prefix + [PAD_ID]
                        feedable_item['decoder_lengths'] = len(confirmed_tokens) + len(order_prefix) + 1

                        # Keeping the set of tokens that can follow the prefix at this loc
                        next_tokens = items_to_expand.setdefault(prefix, set())
                        for other_tokens in loc_tokens:
                            if prefix == tuple(other_tokens[:token_ix]):
                                next_tokens.add(other_tokens[token_ix])

                        # Storing feedable item
                        feedable_items[prefix] = feedable_item

            # 3) Removing items_to_expand with only 1 items
            # We know for sure the probability will be 100%
            for prefix in list(items_to_expand):
                if len(items_to_expand[prefix]) == 1:
                    del items_to_expand[prefix]
                    del feedable_items[prefix]

            # 4) Running all the feedable items
            queue_name = 'policy_expand'
            fetches['%s/items_to_expand' % fetch_prefix] = CompletedFuture(items_to_expand)
            fetches['%s/results' % fetch_prefix] = [self.feedable_dataset.get_results(queue_name, item, **kwargs)
                                                    for item in feedable_items.values()]

            # Prefetching - We only return the fetches
            if is_prefetching:
                return fetches

            # Otherwise, we yield on the fetches
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # Processing fetches
        def softmax(softmax_logits):
            """ Compute softmax values for the logits """
            e_x = np.exp(softmax_logits - softmax_logits.max(axis=-1, keepdims=True))
            return e_x / e_x.sum(axis=-1, keepdims=True)

        items_to_expand = fetches['%s/items_to_expand' % fetch_prefix]
        results = fetches['%s/results' % fetch_prefix]

        # Each result is expected to hold a single fetch (the logits)
        # zip(*[]) yields nothing, so unpacking it would raise a ValueError when there is nothing to expand
        logits = ()
        if results:
            (logits, ) = zip(*results)

        # 5) Computing probs
        probs = {}  # {prefix: {token: prob}}
        for prefix, logit in zip(items_to_expand.keys(), logits):
            # Position of the token following the prefix (the loc marker in the prefix is not a real token)
            logit_ix = TOKENS_PER_ORDER * len(confirmed_locs) + len(prefix) - 1

            # IndexError - Ignoring prefix
            if logit_ix >= len(logit):
                LOGGER.error('Got %d logits, but trying to access logit at index %d. Ignoring prefix.',
                             len(logit),
                             logit_ix)
                LOGGER.error('Prefix: %s - Confirmed locs: %s', prefix, confirmed_locs)
                continue

            tokens_to_expand = sorted(items_to_expand[prefix])
            token_logits = logit[logit_ix]

            # Only selecting the logits that we expect
            # There is currently a bug in the tokenization that could return additional tokens (Issue #331)
            masked_logits = [token_logits[token] for token in tokens_to_expand]
            token_probs = softmax(np.array(masked_logits, dtype=np.float32))

            # Computing the correct probabilities
            probs[prefix] = dict(zip(tokens_to_expand, token_probs))

        # 6) Computing the prob of each order at each location
        results = {}
        for loc in locs:
            loc_results = []

            # Processing each possible order
            for order, tokens in tokenized_orders[loc]:
                order_prob = 1.
                order_log_probs = []

                # Computing the total order probability and each token log probs
                for token_ix in range(1, len(tokens)):
                    prefix = tuple(tokens[:token_ix])
                    next_token = tokens[token_ix]
                    if prefix in probs and next_token in probs[prefix]:
                        token_prob = probs[prefix][next_token]
                        order_prob *= token_prob
                        order_log_probs.append(np.log(np.maximum(token_prob, 1e-8)))
                    else:
                        # The token was not expanded (it is the only possibility, or its prefix was ignored)
                        order_log_probs.append(0.)

                loc_results.append(OrderProbTokenLogProbs(order=order,
                                                          probability=order_prob,
                                                          log_probs=order_log_probs))

            # Sorting loc by probability
            results[loc] = sorted(loc_results, key=lambda item: item.probability, reverse=True)

        # Returning
        return results
Пример #3
0
    def _decode(**fetches):
        """ Performs decoding on the output (order_based model)
            :param fetches: A dictionary of fetches from the model.

            Keys can include:

            - selected_tokens / argmax_tokens: [Required] The tokens from the model (Tensor [batch, decoder_length])
            - log_probs: [Required] The log probs from the model (Tensor [batch, decoder_length])
            - policy_loss: The policy loss for the batch.
            - targets: The targets from the model (Tensor [batch, length]). Required for evaluation.
            - current_power: The current_power from the model (Tensor [batch,]). Required for evaluation.
            - current_season: The current_season from the model (Tensor [batch,]). Required for evaluation.
            - in_retreat_phase: Boolean that indicates dislodged units are on the map. ([b,]). Required for evaluation.
            - request_id: The unique request id for each item in the batch.

            :return: An empty dict if a required fetch is missing, otherwise a dictionary of decoded results, including
                - 1) decoded_orders:
                    A list of dictionary (one per batch) where each dict has location as key and a
                      OrderProbTokenLogProbs tuple as value (i.e. an order, its prob, and the token log probs)
                        e.g. [{'PAR': (order, prob, log_probs),'MAR': (order, prob, log_probs)},
                              {'PAR': (order, prob, log_probs),'MAR': (order, prob, log_probs)}]
                    Locations are truncated to 3 characters (no coast), and only the first order decoded for a
                    location is kept. WAIVE orders are stored under the keys 'WAIVE_0', 'WAIVE_1', ...
                - 2) tokens: A list of dictionary (one per batch) with the same keys, and [token] as value
                - 3) various other keys for evaluation (passed through unchanged, None if not provided)
        """
        # Missing the required fetches, returning an empty decoded results
        has_tokens = 'selected_tokens' in fetches or 'argmax_tokens' in fetches
        if not has_tokens or 'log_probs' not in fetches:
            return {}

        # tokens:           [batch, dec_len]
        # log_probs:        [batch, dec_len]
        # policy_loss:      ()
        # targets:          [batch, dec_len]
        # current_power:    [batch]
        # current_season:   [batch]
        # in_retreat_phase: [batch]
        # request_ids:      [batch]
        tokens = fetches.get('selected_tokens', fetches.get('argmax_tokens'))
        log_probs = fetches['log_probs']

        # Decoding orders
        results = []
        result_tokens = []

        for batch_ix in range(tokens.shape[0]):
            batch_results = OrderedDict()
            batch_results_tokens = OrderedDict()
            batch_tokens = tokens[batch_ix]
            batch_log_probs = log_probs[batch_ix]
            nb_waive = 0

            # We didn't try to predict orders - Skipping
            if not len(batch_tokens) or batch_tokens[0] == [0]:  # pylint: disable=len-as-condition
                results.append(batch_results)
                result_tokens.append(batch_results_tokens)
                continue

            for token_ix, token in enumerate(batch_tokens):
                # Special tokens do not map to an order
                if token <= EOS_ID:
                    continue

                order = ix_to_order(token)

                # WAIVE orders - They have no location, so each one gets its own numbered key
                if order == 'WAIVE':
                    loc = 'WAIVE_{}'.format(nb_waive)
                    nb_waive += 1

                # Use normal location and skip if already stored
                # The coast is dropped *before* the lookup, because results are keyed by the 3-letter location
                # (otherwise a second order for e.g. 'STP/NC' would overwrite the one stored under 'STP')
                else:
                    loc = order.split()[1][:3]
                    if loc in batch_results:
                        continue

                # Storing order
                batch_results[loc] = OrderProbTokenLogProbs(order=order,
                                                            probability=1.,
                                                            log_probs=[batch_log_probs[token_ix]])
                batch_results_tokens[loc] = [token]

            # Done with batch
            results.append(batch_results)
            result_tokens.append(batch_results_tokens)

        # Returning
        return {
            'decoded_orders': results,
            'policy_loss': fetches.get('policy_loss', None),
            'targets': fetches.get('targets', None),
            'tokens': result_tokens,
            'current_power': fetches.get('current_power', None),
            'current_season': fetches.get('current_season', None),
            'in_retreat_phase': fetches.get('in_retreat_phase', None),
            'request_id': fetches.get('request_id', None),
            'log_probs': log_probs
        }
Пример #4
0
    def expand(self, confirmed_orders, locs, state_proto, power_name, phase_history_proto, possible_orders_proto,
               **kwargs):
        """ Computes the conditional probability of possible orders for each loc given the confirmed orders.

            This is a generator: unless it is prefetching or post-fetching, it yields the pending fetches once and
            expects the resolved fetches to be sent back before computing the probabilities.

            :param confirmed_orders: The list of orders on which to condition the probs (e.g. ['A PAR H', 'A MAR - SPA']
            :param locs: The locations for which we want to compute probabilities
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the probabilities
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False,
                    A dictionary with every location in locs as key, and a list of tuples where each tuple is composed
                     of 1) an order, 2) the order conditional probability, 3) the conditional log probs of each token
                        e.g. {'PAR': [('A PAR H', 0.000, [...]), ('A PAR - BUR', 0.000, [...]), ...]}
                    Locations without any candidate are mapped to an empty list.
        """
        # Determining if we need to prefetch or postfetch
        fetches = kwargs.get('fetches', {})
        is_prefetching = kwargs.get('prefetch', False)
        is_postfetching = fetches and not is_prefetching
        fetch_prefix = 'expand'

        # Locations
        locs = [loc[:3] for loc in locs]
        confirmed_locs = [order.split()[1][:3] for order in confirmed_orders]

        # Getting fetches
        if not is_postfetching:

            confirmed_tokens = []
            for order in confirmed_orders:
                confirmed_tokens += self.tokenize(order)

            # Building all feedable items
            feedable_items = OrderedDict()  # {loc => feedable_item}
            candidates = OrderedDict()      # {loc => list of candidate tokens}
            for loc in locs:
                feedable_item = self.feedable_dataset.get_feedable_item(confirmed_locs + [loc],
                                                                        state_proto,
                                                                        power_name,
                                                                        phase_history_proto,
                                                                        possible_orders_proto,
                                                                        **kwargs)

                # Only the candidates of the last position (the loc to expand) are used, padding is dropped
                loc_candidates = [candidate for candidate in feedable_item['candidates'][-1] if candidate > PAD_ID]

                # No candidates - Can't expand
                if not loc_candidates:
                    continue

                # Setting the decoder input
                # We need to set a dummy target for the loc to expand - It will not be used by the decoder
                feedable_item['decoder_inputs'] = [GO_ID] + confirmed_tokens + [loc_candidates[0]]
                feedable_item['decoder_lengths'] = len(confirmed_tokens) + 1

                # Storing
                feedable_items[loc] = feedable_item
                candidates[loc] = loc_candidates

            # Running all the feedable items
            queue_name = 'policy_expand'
            fetches['%s/locs' % fetch_prefix] = CompletedFuture(list(feedable_items))
            fetches['%s/candidates' % fetch_prefix] = CompletedFuture(candidates)
            fetches['%s/results' % fetch_prefix] = [self.feedable_dataset.get_results(queue_name, item, **kwargs)
                                                    for item in feedable_items.values()]

            # Prefetching - We only return the fetches
            if is_prefetching:
                return fetches

            # Otherwise, we yield on the fetches
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # Processing fetches
        def softmax(softmax_logits):
            """ Compute softmax values for the logits """
            e_x = np.exp(softmax_logits - softmax_logits.max(axis=-1, keepdims=True))
            return e_x / e_x.sum(axis=-1, keepdims=True)

        feedable_locs = fetches['%s/locs' % fetch_prefix]
        candidates = fetches['%s/candidates' % fetch_prefix]
        results = fetches['%s/results' % fetch_prefix]

        # Each result is expected to hold a single fetch (the logits)
        # zip(*[]) yields nothing, so unpacking it would raise a ValueError when no loc could be expanded
        logits = ()
        if results:
            (logits, ) = zip(*results)

        # Computing probabilities
        expand_results = {loc: [] for loc in locs}
        for loc_ix, loc in enumerate(feedable_locs):
            loc_candidates = candidates[loc]

            # Softmax over the logits of the last position, restricted to the candidates of the loc
            loc_cond_probs = softmax(logits[loc_ix][-1][:len(loc_candidates)])

            # iterate over all candidate
            for token, probability in zip(loc_candidates, loc_cond_probs):

                # ignore special tokens (they do not map to an order)
                if token <= EOS_ID:
                    continue

                expand_results[loc].append(OrderProbTokenLogProbs(order=ix_to_order(token),
                                                                  probability=probability,
                                                                  log_probs=[np.log(np.maximum(probability, 1e-8))]))

            # Sorting loc by probability
            expand_results[loc] = sorted(expand_results[loc], key=lambda item: item.probability, reverse=True)

        # Returning
        return expand_results