def profile_model(model_path):
    # B (batch sizes to sweep) and N (timed forward passes per batch size) are
    # assumed to be module-level constants.
    late_game = load_late_game()
    model = load_diplomacy_model(model_path, map_location="cuda", eval=True)
    for game_name, game in [("new_game", Game()), ("late_game", late_game)]:
        print("\n#", game_name)
        inputs = FeatureEncoder().encode_inputs([game])
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        for batch_size in B:
            # tile each [1, ...] input tensor along the batch dimension
            b_inputs = {
                k: v.repeat((batch_size,) + (1,) * (len(v.shape) - 1))
                for k, v in inputs.items()
            }
            with torch.no_grad():
                tic = time.time()
                for _ in range(N):
                    order_idxs, order_scores, cand_scores, final_scores = model(
                        **b_inputs, temperature=1.0
                    )
                toc = time.time() - tic
                print(
                    f"[B={batch_size}] {toc}s / {N}, "
                    f"latency={1000*toc/N}ms, throughput={N*batch_size/toc}/s"
                )

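# Sketch of the batch-tiling idiom used above: repeat a [1, ...] encoded input
# along dim 0 only, leaving all other dims unchanged (illustrative, standalone):
import torch

x = torch.arange(6).reshape(1, 2, 3)  # stands in for a single encoded input
b = x.repeat((4,) + (1,) * (len(x.shape) - 1))
assert b.shape == (4, 2, 3)  # a batch of 4 identical copies
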
def get_orders(self, game, power, *, temperature=None, top_p=None):
    if len(game.get_orderable_locations().get(power, [])) == 0:
        return []  # nothing to order
    temperature = temperature if temperature is not None else self.temperature
    top_p = top_p if top_p is not None else self.top_p

    inputs = FeatureEncoder().encode_inputs([game])
    inputs = {k: v.to(self.device) for k, v in inputs.items()}
    with torch.no_grad():
        order_idxs, cand_idxs, logits, final_scores = self.model(
            **inputs, temperature=temperature, top_p=top_p
        )
        resample_duplicate_disbands_inplace(
            order_idxs, cand_idxs, logits, inputs["x_possible_actions"], inputs["x_in_adj_phase"]
        )
    return decode_order_idxs(order_idxs[0, POWERS.index(power), :])

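# Illustrative sketch of what the `temperature` / `top_p` arguments mean for a
# single categorical distribution (this is not the model's decoder, which
# samples per-location over candidate orders; standalone, torch only):
import torch

def sample_top_p(logits: torch.Tensor, temperature: float = 1.0, top_p: float = 1.0) -> int:
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = probs.sort(descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # keep the smallest prefix whose mass reaches top_p (always keep >= 1 candidate)
    keep = int((cumulative < top_p).sum().item()) + 1
    kept = sorted_probs[:keep] / sorted_probs[:keep].sum()
    return int(sorted_idx[torch.multinomial(kept, 1)].item())
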
def get_values(self, game) -> np.ndarray:
    batch_inputs = FeatureEncoder().encode_inputs([game])
    batch_est_final_scores = self.do_model_request(
        batch_inputs, self.rollout_temperature, self.rollout_top_p, values_only=True
    )
    return batch_est_final_scores[0]

def do_rollout(
    cls,
    *,
    game_json,
    hostport,
    set_orders_dict={},
    temperature,
    top_p,
    max_rollout_length,
    batch_size=1,
    use_predicted_final_scores,
    mix_square_ratio_scoring=0,
    value_hostport=None,
    rollout_value_frac=0,
) -> Tuple[Tuple[Dict, List[Dict]], TimingCtx]:
    """Complete game, optionally setting orders for the current turn

    This method can safely be called in a subprocess

    Arguments:
    - game_json: json-formatted game string, e.g. output of to_saved_game_format(game)
    - hostport: string, "{host}:{port}" of model server
    - set_orders_dict: Dict[power, orders] to set for the current turn
    - temperature: model softmax temperature for rollout policy
    - top_p: probability mass to sample from for rollout policy
    - max_rollout_length: return SC count after at most this many move phases
    - batch_size: # of games to roll out in parallel
    - use_predicted_final_scores: if True, use model's value head for final SC predictions
    - mix_square_ratio_scoring: weight of the current square-score ratio mixed into the final scores
    - value_hostport: optional "{host}:{port}" of a separate value-model server
    - rollout_value_frac: fraction of value added in at each rollout step; must be 0
      when a separate value model is used (see assert below)

    Returns a 2-tuple:
    - results, a 2-tuple:
      - set_orders_dict: Dict[power, orders]
      - list of Dict[power, final_score], len=batch_size
    - timings: a TimingCtx
    """
    timings = TimingCtx()

    with timings("postman.client"):
        client = postman.Client(hostport)
        client.connect(3)
        if value_hostport is not None:
            value_client = postman.Client(value_hostport)
            value_client.connect(3)
        else:
            value_client = client

    with timings("setup"):
        faulthandler.register(signal.SIGUSR2)
        torch.set_num_threads(1)

        games = [pydipcc.Game.from_json(game_json) for _ in range(batch_size)]
        for i in range(len(games)):
            games[i].game_id += f"_{i}"

        est_final_scores = {}  # game id -> np.array len=7

        # set orders if specified
        for power, orders in set_orders_dict.items():
            for game in games:
                game.set_orders(power, list(orders))

        other_powers = [p for p in POWERS if p not in set_orders_dict]

    rollout_start_phase = games[0].current_short_phase
    rollout_end_phase = n_move_phases_later(rollout_start_phase, max_rollout_length)
    while True:
        if max_rollout_length == 0:
            # Handled separately.
            break

        # exit loop if all games are done before max_rollout_length
        ongoing_game_phases = [
            game.current_short_phase for game in games if not game.is_game_done
        ]
        if len(ongoing_game_phases) == 0:
            break

        # step games together at the pace of the slowest game, e.g. process
        # games with retreat phases alone before moving on to the next move phase
        min_phase = min(ongoing_game_phases, key=sort_phase_key)

        batch_data = []
        for game in games:
            if not game.is_game_done and game.current_short_phase == min_phase:
                with timings("encode.all_poss_orders"):
                    # result unused here; the call is timed separately (it also
                    # warms the possible-orders cache)
                    all_possible_orders = game.get_all_possible_orders()
                with timings("encode.inputs"):
                    inputs = FeatureEncoder().encode_inputs([game])
                batch_data.append((game, inputs))

        with timings("cat_pad"):
            xs: List[Tuple] = [b[1] for b in batch_data]
            batch_inputs = cls.cat_pad_inputs(xs)

        with timings("model"):
            if client != value_client:
                assert (
                    rollout_value_frac == 0
                ), "If separate value model, you can't add in value each step (slow)"

            cur_client = value_client if min_phase == rollout_end_phase else client
            batch_orders, _, batch_est_final_scores = cls.do_model_request(
                batch_inputs, temperature, top_p, client=cur_client
            )

        if min_phase == rollout_end_phase:
            with timings("score.accumulate"):
                for game_idx, (game, _) in enumerate(batch_data):
                    est_final_scores[game.game_id] = np.array(
                        batch_est_final_scores[game_idx]
                    )
            # skip env step and exit loop once we've accumulated the estimated
            # scores for all games up to max_rollout_length
            break

        with timings("env"):
            assert len(batch_data) == len(batch_orders), "{} != {}".format(
                len(batch_data), len(batch_orders)
            )

            # set_orders and process
            for (game, _), power_orders in zip(batch_data, batch_orders):
                if game.is_game_done:
                    continue
                power_orders = dict(zip(POWERS, power_orders))
                for other_power in other_powers:
                    game.set_orders(other_power, list(power_orders[other_power]))

                assert game.current_short_phase == min_phase
                game.process()

        for (game, _) in batch_data:
            if game.is_game_done:
                with timings("score.gameover"):
                    final_scores = np.array(get_square_scores_from_game(game))
                    est_final_scores[game.game_id] = final_scores

        other_powers = POWERS  # no set orders on subsequent turns

    # out of rollout loop

    if max_rollout_length > 0:
        assert len(est_final_scores) == len(games)
    else:
        assert (
            not other_powers
        ), "If max_rollout_length=0 it's assumed that all orders are pre-defined."

        # All orders are set. Step env. Now only need to get values.
        for game in games:
            game.process()

        batch_data = []
        for game in games:
            if not game.is_game_done:
                with timings("encode.inputs"):
                    inputs = FeatureEncoder().encode_inputs([game])
                batch_data.append((game, inputs))
            else:
                est_final_scores[game.game_id] = get_square_scores_from_game(game)

        if batch_data:
            with timings("cat_pad"):
                xs: List[Tuple] = [b[1] for b in batch_data]
                batch_inputs = cls.cat_pad_inputs(xs)

            with timings("model"):
                _, _, batch_est_final_scores = cls.do_model_request(
                    batch_inputs, temperature, top_p, client=value_client
                )
                assert batch_est_final_scores.shape[0] == len(batch_data)
                assert batch_est_final_scores.shape[1] == len(POWERS)
                for game_idx, (game, _) in enumerate(batch_data):
                    est_final_scores[game.game_id] = batch_est_final_scores[game_idx]

    with timings("final_scores"):
        # get GameScores objects for the current game state
        current_game_scores = [
            {
                p: compute_game_scores_from_state(i, game.get_state())
                for i, p in enumerate(POWERS)
            }
            for game in games
        ]

        # get estimated or current sum-of-squares scoring
        final_game_scores = [
            dict(zip(POWERS, est_final_scores[game.game_id])) for game in games
        ]

        # mix in the current sum-of-squares ratio to encourage losing powers to try hard
        if mix_square_ratio_scoring > 0:
            for game, final_scores, current_scores in zip(
                games, final_game_scores, current_game_scores
            ):
                for p in POWERS:
                    final_scores[p] = (
                        (1 - mix_square_ratio_scoring) * final_scores[p]
                        + mix_square_ratio_scoring * current_scores[p].square_ratio
                    )

    result = (set_orders_dict, final_game_scores)
    return result, timings

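# Above, the value head's estimate is blended with each power's current
# sum-of-squares share: final[p] = (1 - w) * est[p] + w * square_ratio[p],
# with w = mix_square_ratio_scoring. A minimal sketch of that share
# (`square_ratios` is a hypothetical helper, assuming the SoS rule implied by
# get_square_scores_from_game):
def square_ratios(sc_counts):
    # each power's squared SC count as a fraction of the total squared SC count
    total = sum(c * c for c in sc_counts)
    return [c * c / total for c in sc_counts]
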
def preprocess(self):
    """Pre-process the dataset."""
    assert not self.debug_only_opening_phase, "FIXME"

    logging.info(
        f"Building dataset from {len(self.game_ids)} games, "
        f"only_with_min_final_score={self.only_with_min_final_score} "
        f"value_decay_alpha={self.value_decay_alpha} cf_agent={self.cf_agent}"
    )

    torch.set_num_threads(1)
    encoder = FeatureEncoder()

    encoded_game_tuples = self.mp_encode_games()

    # remove "empty" games (e.g. json didn't exist)
    encoded_games = [g for (_, g) in encoded_game_tuples if g is not None]
    logging.info(f"Found data for {len(encoded_games)} / {len(self.game_ids)} games")

    encoded_games = [g for g in encoded_games if g["valid_power_idxs"][0].any()]
    logging.info(f"{len(encoded_games)} games had data for at least one power")

    # update game_ids
    self.game_ids = [
        g_id
        for (g_id, g) in encoded_game_tuples
        if g_id is not None and g["valid_power_idxs"][0].any()
    ]

    # build parallel index tensors: one entry per valid (game, phase, power)
    game_idxs, phase_idxs, power_idxs, x_idxs = [], [], [], []
    x_idx = 0
    for game_idx, encoded_game in enumerate(encoded_games):
        for phase_idx, valid_power_idxs in enumerate(encoded_game["valid_power_idxs"]):
            assert valid_power_idxs.nelement() == len(POWERS), (
                encoded_game["valid_power_idxs"].shape,
                valid_power_idxs.shape,
            )
            for power_idx in valid_power_idxs.nonzero()[:, 0]:
                game_idxs.append(game_idx)
                phase_idxs.append(phase_idx)
                power_idxs.append(power_idx)
                x_idxs.append(x_idx)
            x_idx += 1

    self.game_idxs = torch.tensor(game_idxs, dtype=torch.long)
    self.phase_idxs = torch.tensor(phase_idxs, dtype=torch.long)
    self.power_idxs = torch.tensor(power_idxs, dtype=torch.long)
    self.x_idxs = torch.tensor(x_idxs, dtype=torch.long)

    # now collate the data into giant tensors!
    self.encoded_games = DataFields.cat(encoded_games)

    self.num_games = len(encoded_games)
    self.num_phases = len(self.encoded_games["x_board_state"]) if self.encoded_games else 0
    self.num_elements = len(self.x_idxs)

    self.validate_dataset()
    self._preprocessed = True

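# Tiny worked example of the index tables built above. Suppose game 0 has one
# phase where powers 0 and 2 are valid, and game 1 has one phase where only
# power 5 is valid. Then:
#   game_idxs  = [0, 0, 1]
#   phase_idxs = [0, 0, 0]
#   power_idxs = [0, 2, 5]
#   x_idxs     = [0, 0, 1]   # one row per (game, phase) in the collated tensors,
#                            # shared by all valid powers of that phase
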
def encode_phase(
    encoder: FeatureEncoder,
    game: Game,
    game_id: str,
    phase_idx: int,
    *,
    only_with_min_final_score: Optional[int],
    cf_agent=None,
    n_cf_agent_samples=1,
    value_decay_alpha,
    input_valid_power_idxs,
    exclude_n_holds,
):
    """
    Arguments:
    - game: Game object
    - game_id: unique id for game
    - phase_idx: int, the index of the phase to encode
    - only_with_min_final_score: if specified, only encode for powers who finish the
      game with at least this many supply centers (i.e. only learn from winners).
      MILA uses 7.

    Returns: DataFields, including y_actions and y_final_score
    """
    phase = game.get_phase_history()[phase_idx]
    rolled_back_game = game.rolled_back_to_phase_start(phase.name)
    data_fields = encoder.encode_inputs([rolled_back_game])

    # encode final scores
    y_final_scores = encode_weighted_sos_scores(game, phase_idx, value_decay_alpha)

    # encode actions
    valid_power_idxs = torch.tensor(input_valid_power_idxs, dtype=torch.bool)
    y_actions_lst = []
    power_orders_samples = (
        {power: [phase.orders.get(power, [])] for power in POWERS}
        if cf_agent is None
        else get_cf_agent_order_samples(
            rolled_back_game, phase.name, cf_agent, n_cf_agent_samples
        )
    )
    for power_i, power in enumerate(POWERS):
        orders_samples = power_orders_samples[power]
        if len(orders_samples) == 0:
            valid_power_idxs[power_i] = False
            y_actions_lst.append(
                torch.empty(n_cf_agent_samples, MAX_SEQ_LEN, dtype=torch.int32).fill_(EOS_IDX)
            )
            continue
        encoded_power_actions_lst = []
        for orders in orders_samples:
            encoded_power_actions, valid = encode_power_actions(
                orders, data_fields["x_possible_actions"][0, power_i]
            )
            encoded_power_actions_lst.append(encoded_power_actions)
            if 0 <= exclude_n_holds <= len(orders):
                if all(o.endswith(" H") for o in orders):
                    valid = 0
            valid_power_idxs[power_i] &= valid
        y_actions_lst.append(torch.stack(encoded_power_actions_lst, dim=0))  # [N, 17]

    y_actions = torch.stack(y_actions_lst, dim=0)  # [7, N, 17]

    # filter away powers that have no orders
    valid_power_idxs &= (y_actions != EOS_IDX).any(dim=2).all(dim=1)
    assert valid_power_idxs.ndimension() == 1

    # Maybe filter away powers that don't finish with enough SC.
    # If all players finish with fewer SC, include everybody.
    # cf. get_top_victors() in mila's state_space.py
    if only_with_min_final_score is not None:
        final_score = {k: len(v) for k, v in game.get_state()["centers"].items()}
        if max(final_score.values()) >= only_with_min_final_score:
            for i, power in enumerate(POWERS):
                if final_score.get(power, 0) < only_with_min_final_score:
                    valid_power_idxs[i] = 0

    data_fields["y_final_scores"] = y_final_scores.unsqueeze(0)
    data_fields["y_actions"] = y_actions.unsqueeze(0)
    data_fields["valid_power_idxs"] = valid_power_idxs.unsqueeze(0)
    data_fields["x_possible_actions"] = TensorList.from_padded(
        data_fields["x_possible_actions"].view(len(POWERS) * MAX_SEQ_LEN, MAX_VALID_LEN),
        padding_value=EOS_IDX,
    )
    return data_fields

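# Example of the exclude_n_holds filter above: with exclude_n_holds=3, an
# all-holds action of >= 3 orders is marked invalid,
#   ("A PAR H", "A MAR H", "F BRE H")  ->  valid = 0
# while a 2-order all-holds action, or any action containing a non-hold order,
# is kept. A negative exclude_n_holds disables the filter entirely.
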
def encode_game(
    game_id: Union[int, str],
    data_dir: str,
    only_with_min_final_score=7,
    *,
    cf_agent=None,
    n_cf_agent_samples=1,
    input_valid_power_idxs,
    value_decay_alpha,
    game_metadata,
    exclude_n_holds,
):
    """
    Arguments:
    - game_id: game id, or a path to the game json (see below)
    - only_with_min_final_score: if specified, only encode for powers who finish the
      game with at least this many supply centers (i.e. only learn from winners).
      MILA uses 7.
    - input_valid_power_idxs: bool tensor, true if a power should a priori be included
      in the dataset based on e.g. player rating

    Returns: (game_id, DataFields dict of tensors), where L is game length, P is # of
    powers above min_final_score, N is n_cf_agent_samples:
    - board_state: shape=(L, 81, 35)
    - prev_state: shape=(L, 81, 35)
    - prev_orders: shape=(L, 2, 100), dtype=long
    - power: shape=(L, 7, 7)
    - season: shape=(L, 3)
    - in_adj_phase: shape=(L, 1)
    - build_numbers: shape=(L, 7)
    - final_scores: shape=(L, 7)
    - possible_actions: TensorList shape=(L x 7, 17 x 469)
    - loc_idxs: shape=(L, 7, 81), int8
    - actions: shape=(L, 7, N, 17) int order idxs, N=n_cf_agent_samples
    - valid_power_idxs: shape=(L, 7) bool mask of valid powers at each phase
    """
    torch.set_num_threads(1)
    encoder = FeatureEncoder()

    if isinstance(game_id, str):
        game_path = game_id
    else:
        # Hacky fix to handle game_ids that are paths.
        game_path = os.path.join(f"{data_dir}", f"game_{game_id}.json")

    try:
        with open(game_path) as f:
            game = Game.from_json(f.read())
    except (FileNotFoundError, json.decoder.JSONDecodeError) as e:
        print(f"Error while loading game at {game_path}: {e}")
        return None, None

    num_phases = len(game.get_phase_history())
    logging.info(f"Encoding {game.game_id} with {num_phases} phases")

    phase_encodings = [
        encode_phase(
            encoder,
            game,
            game_id,
            phase_idx,
            only_with_min_final_score=only_with_min_final_score,
            cf_agent=cf_agent,
            n_cf_agent_samples=n_cf_agent_samples,
            value_decay_alpha=value_decay_alpha,
            input_valid_power_idxs=input_valid_power_idxs,
            exclude_n_holds=exclude_n_holds,
        )
        for phase_idx in range(num_phases)
    ]

    stacked_encodings = DataFields.cat(phase_encodings)
    return game_id, stacked_encodings.to_storage_fmt_()

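# Usage sketch (illustrative, hypothetical argument values; game_metadata is
# unused in the body shown above and passed as None):
# game_id, fields = encode_game(
#     12345,
#     data_dir="/data/games",            # reads /data/games/game_12345.json
#     only_with_min_final_score=7,
#     input_valid_power_idxs=[True] * len(POWERS),
#     value_decay_alpha=1.0,
#     game_metadata=None,
#     exclude_n_holds=-1,                # negative disables the all-holds filter
# )
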
def do_rollouts(
    self, game_init, set_orders_dicts, average_n_rollouts=1, timings=None, log_timings=False
):
    if timings is None:
        timings = TimingCtx()

    if self.clear_old_all_possible_orders:
        with timings("clear_old_orders"):
            game_init = pydipcc.Game(game_init)
            game_init.clear_old_all_possible_orders()

    with timings("clone"):
        # one clone per (set_orders_dict, rollout) pair
        games = [
            pydipcc.Game(game_init)  # clones
            for _ in set_orders_dicts
            for _ in range(average_n_rollouts)
        ]

    with timings("setup"):
        for i in range(len(games)):
            games[i].game_id += f"_{i}"

        est_final_scores = {}  # game id -> np.array len=7

        # set orders if specified; `repeat` yields each set_orders_dict
        # average_n_rollouts times, so it lines up with `games`
        for game, set_orders_dict in zip(
            games, repeat(set_orders_dicts, average_n_rollouts)
        ):
            for power, orders in set_orders_dict.items():
                game.set_orders(power, list(orders))

        # for each game, a list of powers whose orders need to be generated
        # by the model on the first phase.
        missing_start_orders = {
            game.game_id: frozenset(p for p in POWERS if p not in set_orders_dict)
            for game, set_orders_dict in zip(
                games, repeat(set_orders_dicts, average_n_rollouts)
            )
        }

    if self.max_rollout_length > 0:
        rollout_end_phase_id = sort_phase_key(
            n_move_phases_later(game_init.current_short_phase, self.max_rollout_length)
        )
        max_steps = 1000000
    else:
        # Really far ahead.
        rollout_end_phase_id = sort_phase_key(
            n_move_phases_later(game_init.current_short_phase, 10)
        )
        max_steps = 1

    # This loop steps the games until one of the conditions is true:
    #   - all games are done
    #   - at least one game was stepped for max_steps steps
    #   - all games are either completed or reach a phase such that
    #     sort_phase_key(phase) >= rollout_end_phase_id
    for step_id in range(max_steps):
        ongoing_game_phases = [
            game.current_short_phase for game in games if not game.is_game_done
        ]

        if len(ongoing_game_phases) == 0:
            # all games are done
            break

        # step games together at the pace of the slowest game, e.g. process
        # games with retreat phases alone before moving on to the next move phase
        min_phase = min(ongoing_game_phases, key=sort_phase_key)
        if sort_phase_key(min_phase) >= rollout_end_phase_id:
            break

        games_to_step = [
            game
            for game in games
            if not game.is_game_done and game.current_short_phase == min_phase
        ]

        if step_id > 0 or any(missing_start_orders.values()):
            with timings("encoding"):
                batch_inputs = FeatureEncoder().encode_inputs(games_to_step)

            batch_orders, _, _ = self.do_model_request(
                batch_inputs, self.rollout_temperature, self.rollout_top_p, timings=timings
            )

            with timings("env.set_orders"):
                assert len(games_to_step) == len(batch_orders)
                for game, orders_per_power in zip(games_to_step, batch_orders):
                    for power, orders in zip(POWERS, orders_per_power):
                        if step_id == 0 and power not in missing_start_orders[game.game_id]:
                            continue
                        game.set_orders(power, list(orders))

        with timings("env.step"):
            self.thread_pool.process_multi([game for game in games_to_step])

    # Compute SoS for done games and query the net for not-done games.
    not_done_games = [game for game in games if not game.is_game_done]
    if not_done_games:
        with timings("encoding"):
            batch_inputs = FeatureEncoder().encode_inputs(not_done_games)

        batch_est_final_scores = self.do_model_request(
            batch_inputs,
            self.rollout_temperature,
            self.rollout_top_p,
            timings=timings,
            values_only=True,
        )
        for game_idx, game in enumerate(not_done_games):
            est_final_scores[game.game_id] = np.array(batch_est_final_scores[game_idx])

    for game in games:
        if game.is_game_done:
            est_final_scores[game.game_id] = np.array(get_square_scores_from_game(game))

    with timings("final_scores"):
        final_game_scores = [
            dict(zip(POWERS, est_final_scores[game.game_id])) for game in games
        ]

        # mix in the current sum-of-squares ratio to encourage losing powers to try hard
        if self.mix_square_ratio_scoring > 0:
            for game, final_scores in zip(games, final_game_scores):
                current_scores = game.get_square_scores()
                for pi, p in enumerate(POWERS):
                    final_scores[p] = (
                        (1 - self.mix_square_ratio_scoring) * final_scores[p]
                        + self.mix_square_ratio_scoring * current_scores[pi]
                    )

        # average the per-rollout score dicts for each set_orders_dict
        r = [
            (set_orders_dict, average_score_dicts(scores_dicts))
            for set_orders_dict, scores_dicts in zip(
                set_orders_dicts, groups_of(final_game_scores, average_n_rollouts)
            )
        ]

    if log_timings:
        timings.pprint(logging.getLogger("timings").info)

    return r

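# Plausible semantics of the two grouping helpers used above, inferred from
# their call sites (sketches, not necessarily this repo's implementations):
def groups_of(seq, n):
    # chunk a flat list into consecutive groups of n; games are laid out as
    # average_n_rollouts consecutive clones per set_orders_dict
    return [seq[i : i + n] for i in range(0, len(seq), n)]

def average_score_dicts(score_dicts):
    # average per-power scores across one group of rollouts
    return {p: sum(d[p] for d in score_dicts) / len(score_dicts) for p in score_dicts[0]}
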
def yield_rollouts(
    *,
    exploit_hostports: Sequence[str],
    blueprint_hostports: Optional[Sequence[str]],
    game_json_paths: Optional[Sequence[str]],
    mode: RolloutMode,
    fast_finish: bool = False,
    temperature=0.05,
    max_rollout_length=40,
    batch_size=1,
    seed=0,
) -> Generator[ExploitRollout, None, None]:
    """Do non-stop rollouts for 1 (exploit) vs 6 (blueprint).

    This method can safely be called in a subprocess.

    Arguments:
    - exploit_hostports: list of "{host}:{port}" of model servers for the agent
      that is being trained.
    - blueprint_hostports: list of "{host}:{port}" of model servers for the agents
      being exploited. Ignored in SELFPLAY mode.
    - game_json_paths: either None or a list of paths to json-serialized games.
    - mode: what kind of rollout to do. Defines what will be output.
    - fast_finish: if True, the rollout is stopped once all non-blueprint agents lost.
    - temperature: model softmax temperature for the rollout policy of the blueprint agent.
    - max_rollout_length: return SC count after at most this many steps.
    - batch_size: # of games to roll out in parallel.
    - seed: random seed.

    Yields ExploitRollout objects.
    """
    timings = TimingCtx()

    def create_client_selector(hostports):
        clients = []
        for hostport in hostports:
            client = postman.Client(hostport)
            client.connect(3)
            clients.append(client)
        return iter(itertools.cycle(clients))

    exploit_client_selector = create_client_selector(exploit_hostports)
    if mode != RolloutMode.SELFPLAY:
        assert blueprint_hostports is not None
        blueprint_client_selector = create_client_selector(blueprint_hostports)

    faulthandler.register(signal.SIGUSR2)
    torch.set_num_threads(1)

    exploit_power_selector = itertools.cycle(tuple(range(len(POWERS))))
    for _ in range(seed % len(POWERS)):
        next(exploit_power_selector)

    def yield_game():
        # seed the RNG once, outside the loop, so successive games differ
        rng = np.random.RandomState(seed=seed)
        while True:
            if game_json_paths is None:
                yield Game()
            else:
                p = game_json_paths[rng.choice(len(game_json_paths))]
                with open(p) as stream:
                    game_serialized = stream.read()
                yield Game.from_json(game_serialized)

    game_selector = yield_game()

    for exploit_power_id in exploit_power_selector:
        if mode == RolloutMode.SELFPLAY:
            # Not used.
            del exploit_power_id
            exploit_ids = frozenset(range(len(POWERS)))
        else:
            exploit_ids = frozenset([exploit_power_id])

        with timings("setup"):
            games = [next(game_selector) for _ in range(batch_size)]
            first_phases = [len(game.get_phase_history()) for game in games]
            turn_idx = 0
            observations = {i: [] for i in range(batch_size)}
            actions = {i: [] for i in range(batch_size)}
            cand_actions = {i: [] for i in range(batch_size)}
            logprobs = {i: [] for i in range(batch_size)}

        while turn_idx < max_rollout_length:
            with timings("prep"):
                batch_data = []
                for batch_idx, game in enumerate(games):
                    if game.is_game_done:
                        continue
                    if fast_finish:
                        # stop early once every non-exploit power is eliminated
                        last_centers = game.get_state()["centers"]
                        if not any(
                            len(last_centers[p]) > 0
                            for i, p in enumerate(POWERS)
                            if i not in exploit_ids
                        ):
                            continue
                    inputs = FeatureEncoder().encode_inputs([game])
                    batch_data.append((game, inputs, batch_idx))

            if not batch_data:
                # All games are done.
                break

            with timings("cat_pad"):
                xs: List[Tuple] = [b[1] for b in batch_data]
                batch_inputs = cat_pad_inputs(xs)

            with timings("model"):
                if mode != RolloutMode.SELFPLAY:
                    (blueprint_batch_order_ids,) = do_model_request(
                        next(blueprint_client_selector), batch_inputs, temperature
                    )
                (
                    exploit_batch_order_ids,
                    exploit_cand_ids,
                    exploit_order_logprobs,
                ) = do_model_request(next(exploit_client_selector), batch_inputs)

            with timings("merging"):
                if mode == RolloutMode.SELFPLAY:
                    batch_order_idx = exploit_batch_order_ids
                else:
                    # Use all orders from the blueprint model except the ones
                    # for the exploit power.
                    batch_order_idx = blueprint_batch_order_ids
                    batch_order_idx[:, exploit_power_id] = exploit_batch_order_ids[
                        :, exploit_power_id
                    ]

                batch_orders = strigify_orders_idxs(batch_order_idx)

            with timings("env"):
                assert len(batch_data) == len(batch_orders), "{} != {}".format(
                    len(batch_data), len(batch_orders)
                )

                # set_orders and process
                for (game, _, _), power_orders in zip(batch_data, batch_orders):
                    for power, orders in zip(POWERS, power_orders):
                        game.set_orders(power, list(orders))

                for inner_index, (game, _, i) in enumerate(batch_data):
                    if not game.is_game_done:
                        game.process()
                    if mode == RolloutMode.SELFPLAY:
                        actions[i].append(exploit_batch_order_ids[inner_index])
                        cand_actions[i].append(exploit_cand_ids[inner_index])
                        logprobs[i].append(exploit_order_logprobs[inner_index])
                    else:
                        actions[i].append(
                            exploit_batch_order_ids[inner_index, exploit_power_id]
                        )
                        cand_actions[i].append(exploit_cand_ids[inner_index, exploit_power_id])
                        logprobs[i].append(
                            exploit_order_logprobs[inner_index, exploit_power_id]
                        )
                    observations[i].append(batch_inputs.select(inner_index))

            turn_idx += 1

        logging.debug(
            f"end do_rollout pid {os.getpid()} for {batch_size} games in {turn_idx} turns. timings: "
            f"{ {k: float('{:.3}'.format(v)) for k, v in timings.items()} }."
        )

        for i in range(batch_size):
            final_game_json = json.loads(games[i].to_json())
            if mode == RolloutMode.SELFPLAY:
                for power_id in range(len(POWERS)):
                    extended_obs = DataFields.stack(observations[i])
                    extended_obs["cand_indices"] = torch.stack(
                        [x[power_id] for x in cand_actions[i]], 0
                    )
                    yield ExploitRollout(
                        power_id=power_id,
                        game_json=final_game_json,
                        actions=torch.stack([x[power_id] for x in actions[i]], 0),
                        logprobs=torch.stack([x[power_id] for x in logprobs[i]], 0),
                        observations=extended_obs,
                        first_phase=first_phases[i],
                    )
            else:
                extended_obs = DataFields.stack(observations[i])
                extended_obs["cand_indices"] = torch.stack(cand_actions[i], 0)
                yield ExploitRollout(
                    power_id=exploit_power_id,
                    game_json=final_game_json,
                    actions=torch.stack(actions[i], 0),
                    logprobs=torch.stack(logprobs[i], 0),
                    observations=extended_obs,
                    first_phase=first_phases[i],
                )

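# Note on the seeded power selector above: workers with different seeds start
# the 7-power round-robin at different offsets, so parallel rollout workers
# cover all exploit powers. E.g. with the usual alphabetical POWERS ordering,
# seed=0 starts at AUSTRIA and seed=3 starts at GERMANY.
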
def get_plausible_orders(
    self,
    game,
    *,
    n=1000,
    temperature=1.0,
    limit: Union[int, Sequence[int]],  # limit, or list of limits per power
    batch_size=500,
    top_p=1.0,
) -> Dict[str, Dict[Tuple[str], float]]:
    assert n % batch_size == 0, f"{n}, {batch_size}"

    # limits is a list of 7 limits
    limits = [limit] * 7 if isinstance(limit, int) else limit
    assert len(limits) == 7
    del limit

    # trivial return case: all powers have at most `limit` actions
    # orderable_locs = game.get_orderable_locations()
    # if max(map(len, orderable_locs.values())) <= 1:
    #     all_orders = game.get_all_possible_orders()
    #     pow_orders = {
    #         p: all_orders[orderable_locs[p][0]] if orderable_locs[p] else [] for p in POWERS
    #     }
    #     if all(len(pow_orders[p]) <= limit for p, limit in zip(POWERS, limits)):
    #         return {p: set((x,) for x in orders) for p, orders in pow_orders.items()}

    # non-trivial return case: query model
    counters = {p: Counter() for p in POWERS}

    x = [FeatureEncoder().encode_inputs([game])] * n
    orders_to_logprobs = {}
    for x_chunk in [x[i : i + batch_size] for i in range(0, n, batch_size)]:
        batch_inputs = self.cat_pad_inputs(x_chunk)
        batch_orders, batch_order_logprobs, _ = self.do_model_request(
            batch_inputs, temperature, top_p
        )
        batch_orders = list(zip(*batch_orders))  # power -> list[orders]
        batch_order_logprobs = batch_order_logprobs.t()  # [7 x B]
        for p, power in enumerate(POWERS):
            counters[power].update(batch_orders[p])

        # slow and steady
        for power_orders, power_scores in zip(batch_orders, batch_order_logprobs):
            for order, score in zip(power_orders, power_scores):
                if order not in orders_to_logprobs:
                    orders_to_logprobs[order] = score
                assert (
                    abs(orders_to_logprobs[order] - score) < 1e-2
                ), f"{order} : {orders_to_logprobs[order]} != {score}"

    logging.info(
        "get_plausible_orders(n={}, t={}) found {} unique sets, choosing top {}".format(
            n, temperature, list(map(len, counters.values())), limits
        )
    )

    # filter out badly-coordinated actions
    counters = {
        power: (
            filter_keys(counter, self.is_plausible_orders)
            if len(counter) > limit
            else counter
        )
        for (power, counter), limit in zip(counters.items(), limits)
    }

    most_common = {
        power: sorted(counter.most_common(), key=lambda o: -orders_to_logprobs[o[0]])[:limit]
        for (power, counter), limit in zip(counters.items(), limits)
    }

    # # choose most common
    # most_common = {
    #     power: counter.most_common(limit)
    #     for (power, counter), limit in zip(counters.items(), limits)
    # }

    try:
        logging.info(
            "get_plausible_orders filtered down to {} unique sets, n_0={}, n_cut={}".format(
                list(map(len, counters.values())),
                [safe_idx(most_common[p], 0, default=(None, None))[1] for p in POWERS],
                [
                    safe_idx(most_common[p], limit - 1, default=(None, None))[1]
                    for (p, limit) in zip(POWERS, limits)
                ],
            )
        )
    except Exception:  # TODO: remove this if not seen in production
        logging.warning("error in get_plausible_orders logging")

    logging.info("Plausible orders:")
    logging.info("        count,count_frac,prob")
    for power, orders_and_counts in most_common.items():
        logging.info(f"    {power}")
        for orders, count in orders_and_counts:
            logging.info(
                f"        {count:5d} {count/n:10.5f} "
                f"{np.exp(orders_to_logprobs[orders]):10.5f} {orders}"
            )

    return {
        power: {orders: orders_to_logprobs[orders] for orders, _ in orders_and_counts}
        for power, orders_and_counts in most_common.items()
    }
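
# Plausible semantics of the small helpers used above, inferred from their
# call sites (sketches, not necessarily this repo's implementations):
from collections import Counter

def filter_keys(counter, predicate):
    # keep only the keys that pass the predicate, preserving counts
    return Counter({k: v for k, v in counter.items() if predicate(k)})

def safe_idx(seq, i, default=None):
    # seq[i], or `default` when i is out of range
    return seq[i] if i < len(seq) else default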