Beispiel #1
0
    def _get_detailed_results(decoded_results, feed_dict, evaluation_loop_ix):
        """ Builds per-category accuracy statistics for one decoded batch
            :param decoded_results: The decoded results (output of _decode() function)
            :param feed_dict: The feed dictionary that was given to session.run() (not used here)
            :param evaluation_loop_ix: The current evaluation loop index (0 -> '[TF]' prefix, otherwise '[Gr]')
            :return: An ordered dictionary with result_name as key and a list of result values (Detailed results)
                     An empty ordered dictionary is returned when a field required for the stats is missing.
        """
        del feed_dict  # Unused args

        targets = decoded_results['targets']
        log_probs = decoded_results['log_probs']
        request_ids = decoded_results['request_id']
        batch_size, nb_locs_per_target = targets.shape[0], targets.shape[1]
        decoded_orders = decoded_results['decoded_orders']

        # The additional info is mandatory - Without it, no category can be determined
        for required_field in ('current_power', 'current_season', 'in_retreat_phase'):
            if required_field not in decoded_results:
                LOGGER.warning('The field "%s" is missing. Cannot compute stats', required_field)
                return OrderedDict()

        power_names = [POWER_VOCABULARY_IX_TO_KEY[power_ix] for power_ix in decoded_results['current_power']]
        season_names = ['SFW'[season_ix] for season_ix in decoded_results['current_season']]
        in_retreat_phase = decoded_results['in_retreat_phase']

        # Phase implied by the order type when no unit is dislodged (M = Movement, R = Retreats, A = Adjustments)
        phase_of_order_type = {'H': 'M', '-': 'M', '- VIA': 'M', 'S': 'M', 'C': 'M',
                               'R': 'R',
                               'D': 'A', 'B': 'A', 'WAIVE': 'A'}

        # Prefix
        prefix = '[TF]' if evaluation_loop_ix == 0 else '[Gr]'

        # Creating every bucket up front so that they always appear in the same order
        results = OrderedDict()
        results[prefix + 'Accuracy'] = []
        results[prefix + 'LogProbsDetails'] = [{}]                      # {request_id: (log_probs, mismatch)}
        bucket_names = list(POWER_VOCABULARY_LIST)
        bucket_names += ['Order %s' % kind for kind in ('H', '-', '- VIA', 'S', 'C', 'R', 'B', 'D', 'WAIVE')]
        bucket_names += ['Season %s' % letter for letter in 'SFW']      # Spring, Fall, Winter
        bucket_names += ['Phase %s' % letter for letter in 'MRA']       # Movement, Retreats, Adjustments
        bucket_names += ['Position %d' % pos for pos in range(-1, NB_SUPPLY_CENTERS)]   # -1 is for Adjustments
        bucket_names += ['Loc %s' % topo_loc for topo_loc in sorted(STANDARD_TOPO_LOCS)]
        for bucket_name in bucket_names:
            results[prefix + bucket_name] = []
        log_probs_details = results[prefix + 'LogProbsDetails'][0]

        # Computing accuracy
        for batch_ix in range(batch_size):
            request_id = request_ids[batch_ix]
            has_mismatch = False
            waive_count = 0

            # We didn't learn a policy - Skipping
            if not len(targets[batch_ix]) or targets[batch_ix][0] == 0:  # pylint: disable=len-as-condition
                continue

            for loc_ix in range(nb_locs_per_target):
                target_ix = targets[batch_ix][loc_ix]

                # The first token that is not a real order ends the list of targets
                if not target_ix > EOS_ID:
                    break
                target_order = ix_to_order(target_ix)
                if not target_order:
                    break

                # Finding the location and the type of the target order
                if target_order == 'WAIVE':
                    loc = 'WAIVE_{}'.format(waive_count)
                    order_type = 'WAIVE'
                    waive_count += 1
                else:
                    words = target_order.split()
                    loc = words[1][:3]
                    order_type = words[2] if len(words) > 2 else 'H'
                    if order_type == '-' and words[-1] == 'VIA':
                        order_type = '- VIA'

                # Determining categories
                power_name = power_names[batch_ix]
                season = season_names[batch_ix]
                if in_retreat_phase[batch_ix]:
                    phase = 'R'
                    if order_type in ['-', '- VIA']:
                        order_type = 'R'
                else:
                    phase = phase_of_order_type[order_type]

                # Use -1 as position for A phase
                position = loc_ix if phase != 'A' else -1
                stats_key = StatsKey(prefix, power_name, order_type, season, phase, position)

                # An order is a success if the same order was decoded at the same location
                batch_orders = decoded_orders[batch_ix]
                success = int(loc in batch_orders and batch_orders[loc].order == target_order)
                if not success:
                    has_mismatch = True

                # Recording the outcome in every bucket this order belongs to
                touched = [prefix + 'Accuracy',
                           prefix + power_name,
                           prefix + 'Order %s' % order_type,
                           prefix + 'Season %s' % season,
                           prefix + 'Phase %s' % phase,
                           prefix + 'Position %d' % position]
                if order_type != 'WAIVE':
                    touched.append(prefix + 'Loc %s' % loc)
                for bucket_key in touched:
                    results[bucket_key].append(success)
                results.setdefault(stats_key, []).append(success)

            # Storing (log_probs, mismatch)
            log_probs_details[request_id] = (log_probs[batch_ix].sum(), int(has_mismatch))

        # Returning results
        return results
Beispiel #2
0
    def _evaluate(self, decoded_results, feed_dict, eval_loop_ix,
                  incl_detailed):
        """ Calculates the accuracy of the model
            :param decoded_results: The decoded results (output of _decode() function)
            :param feed_dict: The feed dictionary that was given to session.run()
            :param eval_loop_ix: The current evaluation loop index (-1 for training)
            :param incl_detailed: is true if training is over, more statistics can be computed
            :return: A tuple consisting of:
                        1) An ordered dictionary with result_name as key and (weight, value) as value  (Regular results)
                        2) An ordered dictionary with result_name as key and a list of result values  (Detailed results)
        """
        # Detecting if it's our evaluation or not
        if eval_loop_ix == -1:
            eval_loop_ix = 0
        else:
            our_validation = eval_loop_ix in self.my_eval_loop_ixs
            if not our_validation:
                return OrderedDict(), OrderedDict()
            eval_loop_ix = self.my_eval_loop_ixs.index(eval_loop_ix)

        # Evaluating
        policy_loss = decoded_results[
            'policy_loss']  # Avg X-Ent per unit-order
        perplexity = math.exp(policy_loss) if policy_loss <= 100 else float(
            'inf')
        targets = decoded_results['targets']
        batch_size = targets.shape[0]
        nb_locs_per_target = targets.shape[1]
        decoded_orders = decoded_results['decoded_orders']

        # Logging an error if perplexity is inf
        if perplexity == float('inf'):
            for request_id, log_probs in zip(decoded_results['request_id'],
                                             decoded_results['log_probs']):
                if sum(log_probs) <= -100:
                    LOGGER.error(
                        'Request %s has log probs that causes a -inf perplexity.',
                        request_id)

        # Accuracy
        acc_1_num, denom = 0., 0.
        acc_1_no_hold_num, denom_no_hold = 0., 0.
        nb_tokens_match, nb_tokens_total = 0., 0.
        acc_player_num, denom_player = 0., 0.

        # Decoding batch by batch, loc by loc
        for batch_ix in range(batch_size):
            player_order_mismatch = False
            nb_waive = 0

            # We didn't learn a policy - Skipping
            if not len(targets[batch_ix]) or targets[batch_ix][0] == 0:  # pylint: disable=len-as-condition
                continue

            for loc_ix in range(nb_locs_per_target):
                decoded_target = targets[batch_ix][loc_ix]
                decoded_target_order = ix_to_order(
                    decoded_target) if decoded_target > EOS_ID else ''
                if not decoded_target_order:
                    break
                nb_tokens_total += TOKENS_PER_ORDER

                if decoded_target_order == 'WAIVE':
                    loc = 'WAIVE_{}'.format(nb_waive)
                    is_hold_order = False
                    nb_waive += 1
                else:
                    loc = decoded_target_order.split()[1][:3]
                    is_hold_order = len(decoded_target_order.split(
                    )) <= 2 or decoded_target_order.split()[2] == 'H'

                # Computing Acc 1
                denom += 1.
                if not is_hold_order:
                    denom_no_hold += 1.

                # Checking if the target is in the decoded results
                if loc in decoded_orders[batch_ix] and decoded_orders[
                        batch_ix][loc].order == decoded_target_order:
                    acc_1_num += 1.
                    if not is_hold_order:
                        acc_1_no_hold_num += 1.
                else:
                    player_order_mismatch = True

                # Computing Acc Tokens
                tokenized_targets = get_order_tokens(decoded_target_order) + [
                    EOS_TOKEN
                ]
                tokenized_targets += [PAD_TOKEN] * (TOKENS_PER_ORDER -
                                                    len(tokenized_targets))

                tokenized_results = [-1] * TOKENS_PER_ORDER
                if loc in decoded_orders[batch_ix]:
                    tokenized_results = get_order_tokens(
                        decoded_orders[batch_ix][loc].order) + [EOS_TOKEN]
                    tokenized_results += [PAD_TOKEN] * (TOKENS_PER_ORDER -
                                                        len(tokenized_results))

                nb_tokens_match += sum([
                    1. for i in range(TOKENS_PER_ORDER)
                    if tokenized_targets[i] == tokenized_results[i]
                ])

            # Compute accuracy for this phase
            if not player_order_mismatch:
                acc_player_num += 1
            denom_player += 1

        # No orders at all
        if not denom:
            acc_1 = 1.
            acc_1_no_hold = 1.
            acc_tokens = 1.
            acc_player = 1.
        else:
            acc_1 = acc_1_num / (denom + 1e-12)
            acc_1_no_hold = acc_1_no_hold_num / (denom_no_hold + 1e-12)
            acc_tokens = nb_tokens_match / (nb_tokens_total + 1e-12)
            acc_player = acc_player_num / (denom_player + 1e-12)

        # Computing detailed statistics
        detailed_results = OrderedDict()
        if incl_detailed:
            detailed_results = self._get_detailed_results(
                decoded_results, feed_dict, eval_loop_ix)

        # Validating decoder type
        decoder_type = [
            value for tensor, value in feed_dict.items()
            if 'decoder_type' in tensor.name
        ]
        decoder_type = '' if not decoder_type else decoder_type[0][0]

        # 0 - Teacher Forcing results
        if eval_loop_ix == 0:
            assert decoder_type == TRAINING_DECODER
            return OrderedDict({
                '[TF]X-Ent': (denom, policy_loss),
                '[TF]Perplexity': (denom, perplexity),
                '[TF]Acc_1': (denom, 100. * acc_1),
                '[TF]Acc_1_NoHold': (denom_no_hold, 100. * acc_1_no_hold),
                '[TF]Acc_Tokens': (nb_tokens_total, 100. * acc_tokens),
                '[TF]Acc_Player': (denom_player, 100. * acc_player)
            }), detailed_results

        # 1 - Greedy Results
        if eval_loop_ix == 1:
            assert decoder_type == GREEDY_DECODER
            return OrderedDict({
                '[Gr]Acc_1': (denom, 100. * acc_1),
                '[Gr]Acc_1_NoHold': (denom_no_hold, 100. * acc_1_no_hold),
                '[Gr]Acc_Tokens': (nb_tokens_total, 100. * acc_tokens),
                '[Gr]Acc_Player': (denom_player, 100. * acc_player)
            }), detailed_results

        # Otherwise, invalid evaluation_loop_ix
        raise RuntimeError('Invalid evaluation_loop_ix - Got "%s"' %
                           eval_loop_ix)
Beispiel #3
0
    def expand(self, confirmed_orders, locs, state_proto, power_name, phase_history_proto, possible_orders_proto,
               **kwargs):
        """ Computes the conditional probability of possible orders for each loc given the confirmed orders.
            :param confirmed_orders: The list of orders on which to condition the probs (e.g. ['A PAR H', 'A MAR - SPA'])
            :param locs: The locations for which we want to compute probabilities
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the probabilities
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False,
                    A dictionary with every location in locs as key, and a list of tuples where each tuple is composed
                     of 1) an order, 2) the order conditional probability, 3) the conditional log probs of each token
                        e.g. {'PAR': [('A PAR H', 0.000, [...]), ('A PAR - BUR', 0.000, [...]), ...]}
                    Locations without any candidate are mapped to an empty list. If no location can be expanded
                    at all, every location is mapped to an empty list.

            Note: This function yields on the fetches, so it is a generator and its result is delivered as the
                  generator's return value.
        """
        # Determining if we need to prefetch or postfetch
        fetches = kwargs.get('fetches', {})
        is_prefetching = kwargs.get('prefetch', False)
        is_postfetching = fetches and not is_prefetching
        fetch_prefix = 'expand'

        # Locations (without coasts)
        locs = [loc[:3] for loc in locs]
        confirmed_locs = [order.split()[1][:3] for order in confirmed_orders]

        # Getting fetches
        if not is_postfetching:

            confirmed_tokens = []
            for order in confirmed_orders:
                confirmed_tokens += self.tokenize(order)

            # Building all feedable items
            feedable_items = OrderedDict()
            candidates = OrderedDict()
            for loc in locs:
                feedable_item = self.feedable_dataset.get_feedable_item(confirmed_locs + [loc],
                                                                        state_proto,
                                                                        power_name,
                                                                        phase_history_proto,
                                                                        possible_orders_proto,
                                                                        **kwargs)
                loc_candidates = [candidate for candidate in feedable_item['candidates'][-1] if candidate > PAD_ID]

                # No candidates - Can't expand
                if not loc_candidates:
                    continue

                # Setting the decoder input
                # We need to set a dummy target for the loc to expand - It will not be used by the decoder
                feedable_item['decoder_inputs'] = [GO_ID] + confirmed_tokens + [loc_candidates[0]]
                feedable_item['decoder_lengths'] = len(confirmed_tokens) + 1

                # Storing
                feedable_items[loc] = feedable_item
                candidates[loc] = loc_candidates

            # Running all the feedable items
            queue_name = 'policy_expand'
            fetches['%s/locs' % fetch_prefix] = CompletedFuture(list(feedable_items))
            fetches['%s/candidates' % fetch_prefix] = CompletedFuture(candidates)
            fetches['%s/results' % fetch_prefix] = [self.feedable_dataset.get_results(queue_name, item, **kwargs)
                                                    for item in feedable_items.values()]

            # Prefetching - We only return the fetches
            if is_prefetching:
                return fetches

            # Otherwise, we yield on the fetches
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # Processing fetches
        def softmax(softmax_logits):
            """ Compute softmax values for the logits """
            e_x = np.exp(softmax_logits - softmax_logits.max(axis=-1, keepdims=True))
            return e_x / e_x.sum(axis=-1, keepdims=True)

        feedable_locs = fetches['%s/locs' % fetch_prefix]
        candidates = fetches['%s/candidates' % fetch_prefix]
        results = fetches['%s/results' % fetch_prefix]
        expand_results = {loc: [] for loc in locs}

        # Nothing was fed to the model (no locs, or no candidates for any loc)
        # zip(*[]) yields nothing, so unpacking the logits below would raise a ValueError
        if not results:
            return expand_results

        (logits, ) = zip(*results)

        # Computing probabilities
        for loc_ix, loc in enumerate(feedable_locs):
            # Only the last decoding step is for the loc to expand, restricted to its real candidates
            loc_cond_probs = softmax(logits[loc_ix][-1][:len(candidates[loc])])

            # iterate over all candidate
            for candidate_ix, probability in enumerate(loc_cond_probs):
                token = candidates[loc][candidate_ix]

                # Skipping tokens that are not orders (anything at or below EOS_ID)
                if token <= EOS_ID:
                    continue

                expand_results[loc] += [OrderProbTokenLogProbs(order=ix_to_order(token),
                                                               probability=probability,
                                                               log_probs=[np.log(np.maximum(probability, 1e-8))])]

            # Sorting loc by probability
            expand_results[loc] = sorted(expand_results[loc], key=lambda item: item.probability, reverse=True)

        # Returning
        return expand_results
Beispiel #4
0
    def _decode(**fetches):
        """ Performs decoding on the output (order_based model)
            :param fetches: A dictionary of fetches from the model.

            Keys can include:

            - selected_tokens / argmax_tokens: [Required] The tokens from the model (Tensor [batch, decoder_length])
            - log_probs: [Required] The log probs from the model (Tensor [batch, decoder_length])
            - policy_loss: The policy loss for the batch.
            - targets: The targets from the model (Tensor [batch, length]). Required for evaluation.
            - current_power: The current_power from the model (Tensor [batch,]). Required for evaluation.
            - current_season: The current_season from the model (Tensor [batch,]). Required for evaluation.
            - in_retreat_phase: Boolean that indicates dislodged units are on the map. ([b,]). Required for evaluation.
            - request_id: The unique request id for each item in the batch.

            :return: A dictionary of decoded results, including
                - 1) decoded_orders:
                    A list of dictionary (one per batch) where each dict has location as key and a
                      OrderProbTokenLogProbs tuple as value (i.e. an order, its prob, and the token log probs)
                        e.g. [{'PAR': (order, prob, log_probs),'MAR': (order, prob, log_probs)},
                              {'PAR': (order, prob, log_probs),'MAR': (order, prob, log_probs)}]
                - 2) various other keys for evaluation
        """
        # Missing the required fetches, returning an empty decoded results
        if ('selected_tokens' not in fetches and 'argmax_tokens'
                not in fetches) or 'log_probs' not in fetches:
            return {}

        # tokens:           [batch, dec_len]
        # log_probs:        [batch, dec_len]
        # policy_loss:      ()
        # targets:          [batch, dec_len]
        # current_power:    [batch]
        # current_season:   [batch]
        # in_retreat_phase: [batch]
        # request_ids:      [batch]
        tokens = fetches.get('selected_tokens', fetches.get('argmax_tokens'))
        log_probs = fetches['log_probs']
        policy_loss = fetches.get('policy_loss', None)
        targets = fetches.get('targets', None)
        current_power = fetches.get('current_power', None)
        current_season = fetches.get('current_season', None)
        in_retreat_phase = fetches.get('in_retreat_phase', None)
        request_ids = fetches.get('request_id', None)

        # Decoding orders
        results = []
        result_tokens = []
        nb_batches = tokens.shape[0]

        for batch_ix in range(nb_batches):
            batch_results = OrderedDict()
            batch_results_tokens = OrderedDict()
            batch_tokens = tokens[batch_ix]
            batch_log_probs = log_probs[batch_ix]
            nb_waive = 0

            # We didn't try to predict orders - Skipping
            if not len(batch_tokens) or batch_tokens[0] == [0]:  # pylint: disable=len-as-condition
                results += [batch_results]
                result_tokens += [batch_results_tokens]
                continue

            for token_ix, token in enumerate(batch_tokens):
                if token <= EOS_ID:
                    continue

                order = ix_to_order(token)

                # WAIVE orders
                if order == 'WAIVE':
                    loc = 'WAIVE_{}'.format(nb_waive)
                    nb_waive += 1

                # Use normal location and skip if already stored
                else:
                    loc = order.split()[1]
                    if loc in batch_results:
                        continue
                    loc = loc[:3]

                # Storing order
                batch_results[loc] = OrderProbTokenLogProbs(
                    order=order,
                    probability=1.,
                    log_probs=[batch_log_probs[token_ix]])
                batch_results_tokens[loc] = [token]

            # Done with batch
            results += [batch_results]
            result_tokens += [batch_results_tokens]

        # Returning
        return {
            'decoded_orders': results,
            'policy_loss': policy_loss,
            'targets': targets,
            'tokens': result_tokens,
            'current_power': current_power,
            'current_season': current_season,
            'in_retreat_phase': in_retreat_phase,
            'request_id': request_ids,
            'log_probs': log_probs
        }
Beispiel #5
0
    def get_beam_orders(self, locs, state_proto, power_name, phase_history_proto, possible_orders_proto, **kwargs):
        """ Retrieves every beam found by the diverse beam search, along with the probability of each beam
            Beams are ordered by score (highest first).
            :param locs: A list of locations for which we want orders
            :param state_proto: A `.proto.game.State` representation of the state of the game.
            :param power_name: The power name for which we want the orders and the state values
            :param phase_history_proto: A list of `.proto.game.PhaseHistory`. This represents prev phases.
            :param possible_orders_proto: A `proto.game.PossibleOrders` object representing possible order for each loc.
            :param kwargs: Additional optional kwargs:
                - player_seed: The seed to apply to the player to compute a deterministic mask.
                - noise: The sigma of the additional noise to apply to the intermediate layers (i.e. sigma * epsilon)
                - temperature: The temperature to apply to the logits. (Default to 0. for deterministic/greedy)
                - dropout_rate: The amount of dropout to apply to the inputs/outputs of the decoder.
                - with_state_value: Boolean that indicates to also query the value function.
                - retry_on_failure: Boolean that indicates to retry querying from the model if an error is encountered.
                - prefetch: Boolean that indicates to return a dictionary of fetches (str: PrefetchedItem/Future)
                - fetches: Dictionary of (str: future_results) that was computed with prefetch=True
            :return:
                - if prefetch=True, a dictionary of fetches (key as string, value is a future (or list) to yield on)
                - if prefetch=False and with_state_value=False (default), a tuple consisting of:
                     1) A list of beams (i.e. a list of selected orders for each beam)
                     2) A list of probability (the probability of selecting each beam)
                - if prefetch=False and with_state_value=True, a tuple consisting of:
                     1) A list of beams (i.e. a list of selected orders for each beam)
                     2) A list of probability (the probability of selecting each beam)
                     3) The state value for the given state
                - Empty lists (and a state value of 0.) are returned when the decode fetches are None.
        """
        fetch_prefix = 'get_orders'
        decode_key = '%s/decode_fetches' % fetch_prefix
        fetches = kwargs.get('fetches', {})
        is_prefetching = kwargs.get('prefetch', False)
        with_state_value = kwargs.get('with_state_value', False)

        def _as_result(beam_list, prob_list, value):
            """ Builds the returned tuple - The state value is only included when it was requested """
            if with_state_value:
                return (beam_list, prob_list, value)
            return (beam_list, prob_list)

        # The model must be queried unless fetches were provided and we are not prefetching (i.e. postfetching)
        if is_prefetching or not fetches:
            short_locs = [loc[:3] for loc in locs]
            policy_kwargs = strip_keys(kwargs, ['use_beam'])

            # Running policy model
            fetches[decode_key] = self._decode_policy(short_locs,
                                                      state_proto,
                                                      power_name,
                                                      phase_history_proto,
                                                      possible_orders_proto,
                                                      use_beam=True,
                                                      **policy_kwargs)

            # Prefetching - The caller will yield on the fetches
            if is_prefetching:
                return fetches

            # Otherwise, we wait for the fetches here
            fetches = yield process_fetches_dict(self.feedable_dataset, fetches)

        # Nothing was decoded - Returning empty beams
        decode_fetches = fetches[decode_key]                # (beam_orders, beam_log_probs, draw, value)
        if decode_fetches is None:
            return _as_result([], [], 0.)

        # Normalizing the beam log probs into probabilities
        beam_orders, beam_log_probs = decode_fetches[:2]
        normalized_probs = np.exp(beam_log_probs - logsumexp(beam_log_probs))
        adj_probs = apply_temperature(normalized_probs, temperature=1.).tolist()

        # Converting the order ixs of each beam to orders (ixs at or below EOS_ID are dropped)
        beams = [[ix_to_order(order_ix) for order_ix in beam_candidates if order_ix > EOS_ID]
                 for beam_candidates in beam_orders]

        # The state value is the last item of the decode fetches
        state_value = decode_fetches[-1] if with_state_value else 0.
        return _as_result(beams, adj_probs, state_value)