コード例 #1
ファイル: adapter.py プロジェクト: zhanpengfang/research
    def _decode_policy(self, locs, state_proto, power_name,
                       phase_history_proto, possible_orders_proto, **kwargs):
        """ Returns the output of the Policy Model decoder
            :param locs: A list of locations for which we want orders
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the orders and the state values
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - with_state_value: Boolean that indicates to also query the value function.
                - use_beam: Boolean that indicates that we want to use a beam search,
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return: A future (fetches) to yield on.
                     If `locs` is empty, or no feedable item could be built, None is returned instead
                     (wrapped in a `CompletedFuture` when prefetch=True).
        """
        is_prefetching = kwargs.get('prefetch', False)

        # No locations provided, we can return early
        if not locs:
            ret_val = None
            return CompletedFuture(ret_val) if is_prefetching else ret_val

        # Getting feedable item
        feedable_item = self.feedable_dataset.get_feedable_item(
            locs, state_proto, power_name, phase_history_proto,
            possible_orders_proto, **kwargs)
        if not feedable_item:
            # Two separate messages so they are not concatenated into a single run-on line
            LOGGER.warning('The method .get_feedable_item() did not return an item to feed to the model.')
            LOGGER.warning('Make sure you have provided the correct locs and a list of possible orders')
            ret_val = None
            return CompletedFuture(ret_val) if is_prefetching else ret_val

        # Queue - selected by (use_beam, with_state_value)
        with_state_value = kwargs.get('with_state_value', False)
        use_beam = kwargs.get('use_beam', False)
        queue_name = {
            (False, False): 'policy_evaluate',
            (False, True): 'policy_evaluate_with_state_value',
            (True, False): 'policy_beam_search',
            (True, True): 'policy_beam_search_with_state_value'
        }[(use_beam, with_state_value)]
        return self.feedable_dataset.get_results(queue_name, feedable_item, **kwargs)
    def get_state_value(self, state_proto, power_name, phase_history_proto,
                        possible_orders_proto=None, **kwargs):
        """ Computes the value of the current state for a given power
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want to retrieve the value
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
                                          If None, an empty list of possible orders is built for each orderable loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False, a float representing the value of the state of the game to the specified power
        """
        # NOTE(review): this method contains `yield`, so calling it returns a generator; it is presumably driven
        # by a coroutine runner (e.g. a decorator not visible in this listing) - confirm.

        # Determining if we need to prefetch or postfetch
        fetches = kwargs.get('fetches', {})
        is_prefetching = kwargs.get('prefetch', False)
        is_postfetching = fetches and not is_prefetching
        fetch_prefix = 'get_state_value'

        # Getting fetches
        if not is_postfetching:

            if not self.has_value_model:
                LOGGER.warning('This model does not have a value function. Returning a value of 0.')
                return {
                    '%s/ret_val' % fetch_prefix: CompletedFuture(0.)
                } if is_prefetching else 0.

            # Finding orderable locations
            locs, _ = get_orderable_locs_for_powers(state_proto, [power_name])

            # Building a list of empty possible orders
            # The value function should at most only use the first step of the decoder (time step 0)
            # So, we don't need to apply a mask
            if possible_orders_proto is None:
                possible_orders_proto = MapStringList().value  # pylint: disable=no-member
                for loc in locs:
                    # Touching the key presumably creates an (empty) entry for this loc - confirm map semantics
                    possible_orders_proto[loc].value.extend([])

            # Getting feedable item
            feedable_item = self.feedable_dataset.get_feedable_item(
                locs, state_proto, power_name, phase_history_proto,
                possible_orders_proto, **kwargs)
            if not feedable_item:
                LOGGER.warning('The method .get_feedable_item() did not return an item to feed to the model.')
                LOGGER.warning('Make sure you have provided the correct locs and a list of possible orders')
                LOGGER.warning('Returning a value of 0.')
                return {
                    '%s/ret_val' % fetch_prefix: CompletedFuture(0.)
                } if is_prefetching else 0.

            # Selecting queue
            queue_name = 'policy_get_value'
            fetches['%s/state_value' % fetch_prefix] = self.feedable_dataset.get_results(
                queue_name, feedable_item, **kwargs)

            # Prefetching - We only return the fetches
            if is_prefetching:
                return fetches

            # Otherwise, we yield on the fetches
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # Processing fetches - an early-exit value takes precedence
        if '%s/ret_val' % fetch_prefix in fetches:
            return fetches['%s/ret_val' % fetch_prefix]

        # Returning the fetched state value (the queue result is a 1-tuple)
        (state_value, ) = fetches['%s/state_value' % fetch_prefix]
        return state_value
    def expand(self, confirmed_orders, locs, state_proto, power_name,
               phase_history_proto, possible_orders_proto, **kwargs):
        """ Computes the conditional probability of possible orders for each loc given the confirmed orders.
            :param confirmed_orders: The list of orders on which to condition the probs (e.g. ['A PAR H', 'A MAR - SPA']
            :param locs: The locations for which we want to compute probabilities
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the probabilities
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False,
                    A dictionary with every location in locs as key, and a list of tuples where each tuple is composed
                     of 1) an order, 2) the order conditional probability, 3) the conditional log probs of each token
                        e.g. {'PAR': [('A PAR H', 0.000, [...]), ('A PAR - BUR', 0.000, [...]), ...]}
        """
        # pylint: disable=too-many-nested-blocks
        # Determining if we need to prefetch or postfetch
        fetches = kwargs.get('fetches', {})
        is_prefetching = kwargs.get('prefetch', False)
        is_postfetching = fetches and not is_prefetching
        fetch_prefix = 'expand'

        # Locations
        locs = [loc[:3] for loc in locs]
        confirmed_locs = [order.split()[1][:3] for order in confirmed_orders]

        # Getting fetches
        if not is_postfetching:

            confirmed_tokens = []
            for order in confirmed_orders:
                confirmed_tokens += self.tokenize(order)

            # Building a list of conditional probs we want to expand
            # e.g. A PAR H, A PAR - MAR, A PAR - BUR
            # we want P('H' | 'A PAR'), P('MAR', 'A PAR -'), ...
            # Every token sequence starts with a negative marker derived from the loc index (-1, -2, ...)
            # so that identical token prefixes at different locs are kept apart.

            # 1) We compute all the prefix (RHS of the prob) and we only expand where count is > 1
            prefix_count = {}  # {prefix => count}
            for loc in locs:
                for order in possible_orders_proto[loc].value:
                    tokens = [-1 + -1 * locs.index(loc)] + self.tokenize(order)  # e.g. [-2, 25, 4, 25, ...]
                    for token_ix in range(1, len(tokens)):
                        prefix = tuple(tokens[:token_ix])
                        prefix_count[prefix] = prefix_count.get(prefix, 0) + 1

            # 2) Building the list of feedable items only for probs we need to expand
            feedable_items = OrderedDict()      # {prefix => feedable_item}
            items_to_expand = OrderedDict()     # {prefix => set of next available token}
            for loc in locs:
                loc_marker = -1 + -1 * locs.index(loc)
                for order in possible_orders_proto[loc].value:
                    tokens = [loc_marker] + self.tokenize(order)  # e.g. [-2, 25, 4, 25, ...]

                    for token_ix in range(1, len(tokens)):
                        prefix = tuple(tokens[:token_ix])
                        if prefix_count[prefix] <= 1 or prefix in feedable_items:
                            continue

                        feedable_item = self.feedable_dataset.get_feedable_item(
                            confirmed_locs + [loc], state_proto, power_name,
                            phase_history_proto, possible_orders_proto, **kwargs)
                        if not feedable_item:
                            continue

                        # The teacher forcing is GO_ID, the confirmed orders prefix, the actual prefix, and a dummy
                        feedable_item['decoder_inputs'] = [GO_ID] + confirmed_tokens + list(prefix[1:]) + [PAD_ID]
                        feedable_item['decoder_lengths'] = len(confirmed_tokens) + len(prefix[1:]) + 1

                        # Keeping the set of next tokens of every order using the prefix
                        # (orders too short to have a next token after the prefix are skipped)
                        for possible_order in possible_orders_proto[loc].value:
                            new_tokens = [loc_marker] + self.tokenize(possible_order)
                            if len(new_tokens) > token_ix and prefix == tuple(new_tokens[:token_ix]):
                                items_to_expand.setdefault(prefix, set()).add(new_tokens[token_ix])

                        # Storing feedable item
                        feedable_items[prefix] = feedable_item

            # 3) Removing items_to_expand with only 1 items
            # We know for sure the probability will be 100%
            for prefix in list(items_to_expand):
                if len(items_to_expand[prefix]) == 1:
                    del items_to_expand[prefix]
                    del feedable_items[prefix]

            # 4) Running all the feedable items
            # items_to_expand and feedable_items share the same key order, so results line up with prefixes
            queue_name = 'policy_expand'
            fetches['%s/items_to_expand' % fetch_prefix] = CompletedFuture(items_to_expand)
            fetches['%s/results' % fetch_prefix] = [
                self.feedable_dataset.get_results(queue_name, item, **kwargs)
                for item in feedable_items.values()
            ]

            # Prefetching - We only return the fetches
            if is_prefetching:
                return fetches

            # Otherwise, we yield on the fetches
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # Processing fetches
        def softmax(softmax_logits):
            """ Compute softmax values for the logits """
            e_x = np.exp(softmax_logits - softmax_logits.max(axis=-1, keepdims=True))
            return e_x / e_x.sum(axis=-1, keepdims=True)

        items_to_expand = fetches['%s/items_to_expand' % fetch_prefix]
        results = fetches['%s/results' % fetch_prefix]

        # Each result is a 1-tuple (logits, ). Unpacking per item also works when there is nothing
        # to expand (empty results), where `(logits, ) = zip(*results)` would raise ValueError.
        logits = [logit for (logit, ) in results]

        # 5) Computing probs
        probs = {}  # {prefix: {token: prob}}
        for prefix, logit in zip(items_to_expand, logits):
            logit_ix = TOKENS_PER_ORDER * len(confirmed_locs) + len(prefix) - 1

            # IndexError - Ignoring prefix
            if logit_ix >= len(logit):
                LOGGER.error('Got %d logits, but trying to access logit at index %d. Ignoring prefix.',
                             len(logit), logit_ix)
                LOGGER.error('Prefix: %s - Confirmed locs: %s', prefix, confirmed_locs)
                continue

            tokens_to_expand = sorted(items_to_expand[prefix])
            token_logits = logit[logit_ix]

            # Only selecting the logits that we expect
            # There is currently a bug in the tokenization that could return additional tokens (Issue #331)
            masked_logits = [token_logits[token] for token in tokens_to_expand]
            token_probs = softmax(np.array(masked_logits, dtype=np.float32))

            # Computing the correct probabilities
            probs[prefix] = dict(zip(tokens_to_expand, token_probs))

        # 6) Computing the prob of each order at each location
        expand_results = {}
        for loc in locs:
            loc_marker = -1 + -1 * locs.index(loc)
            expand_results[loc] = []

            # Processing each possible order
            for order in possible_orders_proto[loc].value:
                tokens = [loc_marker] + self.tokenize(order)
                order_prob = 1.
                order_log_probs = []

                # Computing the total order probability and each token log probs
                for token_ix in range(1, len(tokens)):
                    prefix = tuple(tokens[:token_ix])
                    if prefix in probs and tokens[token_ix] in probs[prefix]:
                        token_prob = probs[prefix][tokens[token_ix]]
                        order_prob *= token_prob
                        order_log_probs += [np.log(np.maximum(token_prob, 1e-8))]
                    else:
                        # No prob computed for this step (e.g. a single possible continuation) -> log-prob of 0
                        order_log_probs += [0.]

                expand_results[loc] += [OrderProbTokenLogProbs(order=order,
                                                               probability=order_prob,
                                                               log_probs=order_log_probs)]

            # Sorting loc by probability
            expand_results[loc] = sorted(expand_results[loc], key=lambda item: item.probability, reverse=True)

        # Returning
        return expand_results
    def get_updated_policy_details(self, state_proto, power_name, phase_history_proto, possible_orders_proto,
                                   old_policy_details=None, submitted_orders=None, **kwargs):
        """ Computes the current policy details (locs, tokens, log_probs) under the current model
            Either one of 1) old_policy_details or 2) submitted_orders must be submitted to extract the locs and tokens

            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the orders and the state values
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param old_policy_details: (Optional) Some policy details
                                        ==> {'locs', 'tokens', 'log_probs', 'draw_action', 'draw_prob'}
            :param submitted_orders: (Optional) A list of submitted orders ['A PAR - BUR', 'A MAR H']
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False, The corresponding updated policy details
                                    ==> {'locs', 'tokens', 'log_probs', 'draw_action', 'draw_prob'}
                When neither old_policy_details nor submitted_orders is given, or no feedable item can be built,
                empty policy details are returned.
        """
        assert self.feedable_dataset.has_queue('policy_log_probs'), 'Unable to get supervised log probs'

        def _empty_policy_details():
            """ Builds a fresh dictionary of empty policy details """
            return {'locs': [],
                    'tokens': [],
                    'log_probs': [],
                    'draw_action': False,
                    'draw_prob': 0.}

        # Determining if we need to prefetch or postfetch
        fetches = kwargs.get('fetches', {})
        is_prefetching = kwargs.get('prefetch', False)
        is_postfetching = fetches and not is_prefetching
        fetch_prefix = 'get_updated_policy_details'

        # Setting tokens and actual locs
        if old_policy_details:
            actual_locs = old_policy_details['locs']
            tokens = old_policy_details['tokens']

        # Using submitted orders
        # (`or []` so that a missing submitted_orders reaches the explicit warning below instead of a TypeError)
        else:
            actual_locs, tokens = [], []
            for order in submitted_orders or []:
                words = order.split()
                if len(words) >= 2:
                    actual_locs.append(words[1][:3])
                tokens += self.tokenize(order)

        # Getting fetches
        if not is_postfetching:

            if not old_policy_details and not submitted_orders:
                LOGGER.warning('Unable to compute policy details without old policy details or submitted orders.')
                ret_val = _empty_policy_details()
                return {
                    '%s/ret_val' % fetch_prefix: CompletedFuture(ret_val)
                } if is_prefetching else ret_val

            # In adjustment phase, the locs are all the orderable locs
            if state_proto.name[-1] == 'A':
                locs, _ = get_orderable_locs_for_powers(state_proto, [power_name])
            elif old_policy_details:
                locs = old_policy_details['locs']
            else:
                locs = [order.split()[1][:3] for order in submitted_orders if len(order.split()) >= 2]

            # Getting feedable item
            feedable_item = self.feedable_dataset.get_feedable_item(
                locs, state_proto, power_name, phase_history_proto,
                possible_orders_proto, **kwargs)
            if not feedable_item:
                ret_val = _empty_policy_details()
                return {
                    '%s/ret_val' % fetch_prefix: CompletedFuture(ret_val)
                } if is_prefetching else ret_val

            # Teacher forcing on the known tokens
            feedable_item['decoder_inputs'] = [GO_ID] + tokens
            feedable_item['decoder_lengths'] = len(tokens)

            # Querying model
            queue_name = 'policy_log_probs'
            fetches['%s/log_probs_fetches' % fetch_prefix] = self.feedable_dataset.get_results(
                queue_name, feedable_item, **kwargs)

            # Prefetching - We only return the fetches
            if is_prefetching:
                return fetches

            # Otherwise, we yield on the fetches
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # Processing fetches
        if '%s/ret_val' % fetch_prefix in fetches:
            return fetches['%s/ret_val' % fetch_prefix]

        # Keeping only the log probs of the tokens of the actual locs
        new_log_probs, new_draw_prob = fetches['%s/log_probs_fetches' % fetch_prefix]
        new_log_probs = new_log_probs[:len(actual_locs) * TOKENS_PER_ORDER]

        # Validating
        assert submitted_orders is not None or len(new_log_probs) == len(old_policy_details['log_probs'])

        # Returning
        return {'locs': actual_locs,
                'tokens': tokens,
                'log_probs': new_log_probs,
                'draw_action': old_policy_details['draw_action'] if old_policy_details else bool(
                    new_draw_prob >= 0.5),
                'draw_prob': new_draw_prob}
    def expand(self, confirmed_orders, locs, state_proto, power_name, phase_history_proto, possible_orders_proto,
               **kwargs):
        """ Computes the conditional probability of possible orders for each loc given the confirmed orders.
            :param confirmed_orders: The list of orders on which to condition the probs (e.g. ['A PAR H', 'A MAR - SPA']
            :param locs: The locations for which we want to compute probabilities
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the probabilities
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False,
                    A dictionary with every location in locs as key, and a list of tuples where each tuple is composed
                     of 1) an order, 2) the order conditional probability, 3) the conditional log probs of each token
                        e.g. {'PAR': [('A PAR H', 0.000, [...]), ('A PAR - BUR', 0.000, [...]), ...]}
                    Locs with no candidates (or no feedable item) map to an empty list.
        """
        # Determining if we need to prefetch or postfetch
        fetches = kwargs.get('fetches', {})
        is_prefetching = kwargs.get('prefetch', False)
        is_postfetching = fetches and not is_prefetching
        fetch_prefix = 'expand'

        # Locations
        locs = [loc[:3] for loc in locs]
        confirmed_locs = [order.split()[1][:3] for order in confirmed_orders]

        # Getting fetches
        if not is_postfetching:

            confirmed_tokens = []
            for order in confirmed_orders:
                confirmed_tokens += self.tokenize(order)

            # Building all feedable items
            feedable_items = OrderedDict()      # {loc => feedable_item}
            candidates = OrderedDict()          # {loc => list of candidate ids (> PAD_ID)}
            for loc in locs:
                feedable_item = self.feedable_dataset.get_feedable_item(confirmed_locs + [loc],
                                                                        state_proto,
                                                                        power_name,
                                                                        phase_history_proto,
                                                                        possible_orders_proto,
                                                                        **kwargs)
                # No feedable item - Can't expand
                if not feedable_item:
                    continue

                loc_candidates = [candidate for candidate in feedable_item['candidates'][-1] if candidate > PAD_ID]

                # No candidates - Can't expand
                if not loc_candidates:
                    continue

                # Setting the decoder input
                # We need to set a dummy target for the loc to expand - It will not be used by the decoder
                feedable_item['decoder_inputs'] = [GO_ID] + confirmed_tokens + [loc_candidates[0]]
                feedable_item['decoder_lengths'] = len(confirmed_tokens) + 1

                # Storing
                feedable_items[loc] = feedable_item
                candidates[loc] = loc_candidates

            # Running all the feedable items (results are in the same order as the stored locs)
            queue_name = 'policy_expand'
            fetches['%s/locs' % fetch_prefix] = CompletedFuture(list(feedable_items))
            fetches['%s/candidates' % fetch_prefix] = CompletedFuture(candidates)
            fetches['%s/results' % fetch_prefix] = [self.feedable_dataset.get_results(queue_name, item, **kwargs)
                                                    for item in feedable_items.values()]

            # Prefetching - We only return the fetches
            if is_prefetching:
                return fetches

            # Otherwise, we yield on the fetches
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # Processing fetches
        def softmax(softmax_logits):
            """ Compute softmax values for the logits """
            e_x = np.exp(softmax_logits - softmax_logits.max(axis=-1, keepdims=True))
            return e_x / e_x.sum(axis=-1, keepdims=True)

        feedable_locs = fetches['%s/locs' % fetch_prefix]
        candidates = fetches['%s/candidates' % fetch_prefix]
        results = fetches['%s/results' % fetch_prefix]

        # Each result is a 1-tuple (logits, ). Unpacking per item also works when no loc could be
        # expanded (empty results), where `(logits, ) = zip(*results)` would raise ValueError.
        logits = [logit for (logit, ) in results]

        # Computing probabilities
        expand_results = {loc: [] for loc in locs}
        for loc_ix, loc in enumerate(feedable_locs):
            loc_candidates = candidates[loc]
            loc_cond_probs = softmax(logits[loc_ix][-1][:len(loc_candidates)])

            # Iterating over all candidates
            for token, probability in zip(loc_candidates, loc_cond_probs):

                # Skipping special token ids (<= EOS_ID)
                if token <= EOS_ID:
                    continue

                expand_results[loc] += [OrderProbTokenLogProbs(order=ix_to_order(token),
                                                               probability=probability,
                                                               log_probs=[np.log(np.maximum(probability, 1e-8))])]

            # Sorting loc by probability
            expand_results[loc] = sorted(expand_results[loc], key=lambda item: item.probability, reverse=True)

        # Returning
        return expand_results