def cat_pad_inputs(xs: List[DataFields]) -> DataFields:
    """Collate a list of per-sample DataFields into one batched DataFields.

    Keys are taken from the first element. For each key the per-sample
    values are gathered in order and combined:
    - "x_possible_actions": cat_pad_sequences with pad_value=-1, padded to MAX_SEQ_LEN
    - "x_loc_idxs": cat_pad_sequences with pad_value=EOS_IDX, padded to MAX_SEQ_LEN
    - any other key: torch.cat
    """

    def _collate(key, values):
        # Sequence-valued fields need padding to a common length first.
        if key == "x_possible_actions":
            return cat_pad_sequences(values, pad_value=-1, pad_to_len=MAX_SEQ_LEN)
        if key == "x_loc_idxs":
            return cat_pad_sequences(values, pad_value=EOS_IDX, pad_to_len=MAX_SEQ_LEN)
        return torch.cat(values)

    return DataFields(
        {key: _collate(key, [sample[key] for sample in xs]) for key in xs[0].keys()}
    )
def rollout_to_batch(
    rollout: ExploitRollout, reward_kwargs: Dict
) -> Tuple[RolloutBatch, ScoreDict]:
    """Convert a single exploit rollout into a training batch for its power.

    Arguments:
    - rollout: the rollout to convert; observations are reduced to the
      entries for rollout.power_id only
    - reward_kwargs: extra keyword arguments forwarded to compute_reward

    Returns:
    - a RolloutBatch with N = len(rollout.actions) rows; only the last row
      has done=True
    - the dict form of game_scoring.compute_game_scores for rollout.power_id
    """
    # In case rollout started from an existing game. The assert checks
    # that game_json has `offset` + one-per-action + 1 phases.
    offset = rollout.first_phase
    num_json_phases = len(rollout.game_json["phases"])
    num_actions = len(rollout.actions)
    # offset is part of the checked relation, so report it on failure too.
    assert num_json_phases == num_actions + 1 + offset, (
        num_json_phases,
        num_actions,
        offset,
    )
    N = num_actions

    scores = game_scoring.compute_game_scores(rollout.power_id, rollout.game_json)._asdict()
    rewards = compute_reward(rollout, **reward_kwargs)

    is_final = torch.zeros([N], dtype=torch.bool)
    # NOTE(review): this raises IndexError if the rollout has no actions (N == 0).
    is_final[-1] = True

    # Prepare observation to be used for training. Drop information about all
    # powers, but current.
    obs = DataFields(rollout.observations)
    obs["x_loc_idxs"] = obs["x_loc_idxs"][:, rollout.power_id].clone()
    obs["x_possible_actions"] = obs["x_possible_actions"][:, rollout.power_id].clone()

    rollout_batch = RolloutBatch(
        power_ids=torch.full([N], rollout.power_id, dtype=torch.long),
        rewards=rewards,
        observations=obs,
        actions=rollout.actions,
        logprobs=rollout.logprobs,
        done=is_final,
    )
    return rollout_batch, scores
def _join_batches(batches: Sequence[RolloutBatch]) -> RolloutBatch:
    """Concatenate several RolloutBatches field by field along dim 0.

    The "observations" field is merged with DataFields.cat; every other
    field is merged with torch.cat(..., 0).
    """

    def _concat_field(field_name):
        parts = [getattr(batch, field_name) for batch in batches]
        if field_name == "observations":
            return DataFields.cat(parts)
        return torch.cat(parts, 0)

    return RolloutBatch(
        **{field_name: _concat_field(field_name) for field_name in RolloutBatch._fields}
    )
def do_model_request(
    cls,
    x: DataFields,
    temperature: float,
    top_p: float,
    *,
    client: postman.Client,
) -> Tuple[List[List[Tuple[str, ...]]], torch.Tensor, np.ndarray]:
    """Synchronous request to model server.

    Note: ``x`` is modified in place. "temperature" and "top_p" entries of
    shape (B, 1) are added to it before it is sent to the server.

    Arguments:
    - x: a DataFields dict of Tensors, where each tensor's dim=0 is the batch dim
    - temperature: model softmax temperature
    - top_p: probability mass to sample from
    - client: server client; client.evaluate(x) is expected to return
      (order_idxs, order_logprobs, final_scores)

    Returns a 3-tuple:
    - a list (len = batch size) of lists (len = len(POWERS)) of order-tuples
    - order_logprobs, passed through unchanged from client.evaluate
      (presumably a tensor -- TODO confirm)
    - final_scores, converted with .numpy() when it has that method; its
      dim 0 is the batch dim (presumably [B, 7] float32 estimated final
      scores -- TODO confirm)

    Raises:
    - whatever client.evaluate raises, after logging the shapes/dtypes of x
    """
    B = x["x_board_state"].shape[0]
    x["temperature"] = torch.zeros(B, 1).fill_(temperature)
    x["top_p"] = torch.zeros(B, 1).fill_(top_p)
    try:
        order_idxs, order_logprobs, final_scores = client.evaluate(x)
    except Exception:
        # Record what was sent so server-side failures can be diagnosed, then re-raise.
        logging.error(
            "Caught server error with inputs %s",
            [(k, v.shape, v.dtype) for k, v in x.items()],
        )
        raise
    assert x["x_board_state"].shape[0] == final_scores.shape[0], (
        x["x_board_state"].shape[0],
        final_scores.shape[0],
    )
    if hasattr(final_scores, "numpy"):
        final_scores = final_scores.numpy()

    decoded_orders = [
        [tuple(decode_order_idxs(order_idxs[b, p, :])) for p in range(len(POWERS))]
        for b in range(order_idxs.shape[0])
    ]
    return decoded_orders, order_logprobs, final_scores
def from_merge(cls, datasets: Sequence["Dataset"]) -> torch.utils.data.Dataset:
    """Merge several preprocessed Datasets into a single new Dataset.

    game_idxs of each input are shifted by the cumulative num_games of the
    datasets before it, and x_idxs by the cumulative num_phases, so both
    index into the concatenated ``encoded_games``. phase_idxs and
    power_idxs are copied unchanged.

    Arguments:
    - datasets: non-empty sequence of preprocessed datasets, each with
      n_cf_agent_samples == 1

    Returns:
    - a new Dataset (built with data_dir=None and the other config args None)
      with _preprocessed = True

    Raises:
    - ValueError: if ``datasets`` is empty
    - NotImplementedError: if any dataset has n_cf_agent_samples != 1 or is
      not preprocessed
    """
    # Fail early with a clear message instead of inside torch.cat.
    if not datasets:
        raise ValueError("from_merge requires at least one dataset")
    for d in datasets:
        if d.n_cf_agent_samples != 1:
            raise NotImplementedError(
                f"from_merge only supports n_cf_agent_samples == 1, got {d.n_cf_agent_samples}"
            )
        if not d._preprocessed:
            raise NotImplementedError("from_merge requires preprocessed datasets")

    merged = Dataset(
        game_ids=[x for d in datasets for x in d.game_ids],
        data_dir=None,
        game_metadata=None,
        only_with_min_final_score=None,
        num_dataloader_workers=None,
        value_decay_alpha=None,
    )

    # Exclusive prefix sums: offset i is the total size of datasets[:i].
    game_offsets = torch.from_numpy(np.cumsum([0] + [d.num_games for d in datasets[:-1]]))
    phase_offsets = torch.from_numpy(np.cumsum([0] + [d.num_phases for d in datasets[:-1]]))

    merged.game_idxs = torch.cat([d.game_idxs + off for d, off in zip(datasets, game_offsets)])
    merged.phase_idxs = torch.cat([d.phase_idxs for d in datasets])
    merged.power_idxs = torch.cat([d.power_idxs for d in datasets])
    merged.x_idxs = torch.cat([d.x_idxs + off for d, off in zip(datasets, phase_offsets)])
    merged.encoded_games = DataFields.cat([d.encoded_games for d in datasets])

    merged.num_games = sum(d.num_games for d in datasets)
    merged.num_phases = sum(d.num_phases for d in datasets)
    merged.num_elements = sum(d.num_elements for d in datasets)
    merged._preprocessed = True
    return merged
def preprocess(self):
    """Pre-process the dataset: encode all games and build flat index tensors.

    After this call:
    - self.game_ids: ids of the games that were kept (game data loaded and
      at least one valid power in their first phase)
    - self.encoded_games: DataFields.cat of the kept encoded games
    - self.game_idxs / phase_idxs / power_idxs / x_idxs: one entry per
      (game, phase, valid power); x_idxs is the phase's row in
      self.encoded_games
    - self.num_games, self.num_phases, self.num_elements
    """
    assert not self.debug_only_opening_phase, "FIXME"
    logging.info(
        f"Building dataset from {len(self.game_ids)} games, "
        f"only_with_min_final_score={self.only_with_min_final_score} "
        f"value_decay_alpha={self.value_decay_alpha} cf_agent={self.cf_agent}"
    )

    torch.set_num_threads(1)
    encoded_game_tuples = self.mp_encode_games()

    # remove "empty" games (e.g. json didn't exist)
    loaded = [(g_id, g) for (g_id, g) in encoded_game_tuples if g is not None]
    logging.info(f"Found data for {len(loaded)} / {len(self.game_ids)} games")

    # Derive game_ids and encoded_games from one filtered list so that
    # game_idx i refers to the same game in both.
    kept = [(g_id, g) for (g_id, g) in loaded if g["valid_power_idxs"][0].any()]
    encoded_games = [g for (_, g) in kept]
    logging.info(f"{len(encoded_games)} games had data for at least one power")

    # Update game_ids
    self.game_ids = [g_id for (g_id, _) in kept]

    game_idxs, phase_idxs, power_idxs, x_idxs = [], [], [], []
    x_idx = 0  # running row index of the phase in the concatenated encoded_games
    for game_idx, encoded_game in enumerate(encoded_games):
        for phase_idx, valid_power_idxs in enumerate(encoded_game["valid_power_idxs"]):
            assert valid_power_idxs.nelement() == len(POWERS), (
                encoded_game["valid_power_idxs"].shape,
                valid_power_idxs.shape,
            )
            for power_idx in valid_power_idxs.nonzero()[:, 0].tolist():
                game_idxs.append(game_idx)
                phase_idxs.append(phase_idx)
                power_idxs.append(power_idx)
                x_idxs.append(x_idx)
            x_idx += 1

    self.game_idxs = torch.tensor(game_idxs, dtype=torch.long)
    self.phase_idxs = torch.tensor(phase_idxs, dtype=torch.long)
    self.power_idxs = torch.tensor(power_idxs, dtype=torch.long)
    self.x_idxs = torch.tensor(x_idxs, dtype=torch.long)

    # now collate the data into giant tensors!
    self.encoded_games = DataFields.cat(encoded_games)

    self.num_games = len(encoded_games)
    self.num_phases = len(self.encoded_games["x_board_state"]) if self.encoded_games else 0
    self.num_elements = len(self.x_idxs)

    self.validate_dataset()

    self._preprocessed = True
def encode_game(
    game_id: Union[int, str],
    data_dir: str,
    only_with_min_final_score=7,
    *,
    cf_agent=None,
    n_cf_agent_samples=1,
    input_valid_power_idxs,
    value_decay_alpha,
    game_metadata,
    exclude_n_holds,
):
    """Load one game json and encode every phase of it.

    Arguments:
    - game_id: either a path to a game json file (str), or an int id, in which
      case the file `{data_dir}/game_{game_id}.json` is loaded
    - data_dir: directory of game_<id>.json files; only used for int game_ids
    - only_with_min_final_score: if specified, only encode for powers who
      finish the game with some # of supply centers (i.e. only learn from
      winners). MILA uses 7.
    - input_valid_power_idxs: bool tensor, true if power should a priori be
      included in the dataset based on e.g. player rating
    - cf_agent, n_cf_agent_samples, value_decay_alpha, exclude_n_holds:
      forwarded unchanged to encode_phase
    - game_metadata: accepted for interface compatibility; not used here

    Return: game_id, DataFields dict of tensors (converted with
    to_storage_fmt_()), or (None, None) if the file is missing or is not
    valid JSON.
    L is game length, P is # of powers above min_final_score, N is n_cf_agent_samples
    - board_state: shape=(L, 81, 35)
    - prev_state: shape=(L, 81, 35)
    - prev_orders: shape=(L, 2, 100), dtype=long
    - power: shape=(L, 7, 7)
    - season: shape=(L, 3)
    - in_adj_phase: shape=(L, 1)
    - build_numbers: shape=(L, 7)
    - final_scores: shape=(L, 7)
    - possible_actions: TensorList shape=(L x 7, 17 x 469)
    - loc_idxs: shape=(L, 7, 81), int8
    - actions: shape=(L, 7, N, 17) int order idxs, N=n_cf_agent_samples
    - valid_power_idxs: shape=(L, 7) bool mask of valid powers at each phase
    """
    torch.set_num_threads(1)

    if isinstance(game_id, str):
        # Hacky fix to handle game_ids that are paths.
        game_path = game_id
    else:
        game_path = os.path.join(f"{data_dir}", f"game_{game_id}.json")

    try:
        # JSON is UTF-8 by spec; don't depend on the locale's default encoding.
        with open(game_path, encoding="utf-8") as f:
            game = Game.from_json(f.read())
    except (FileNotFoundError, json.decoder.JSONDecodeError) as e:
        logging.error(f"Error while loading game at {game_path}: {e}")
        return None, None

    # Built only after a successful load; it is not needed on the error path.
    encoder = FeatureEncoder()

    num_phases = len(game.get_phase_history())
    logging.info(f"Encoding {game.game_id} with {num_phases} phases")

    phase_encodings = [
        encode_phase(
            encoder,
            game,
            game_id,
            phase_idx,
            only_with_min_final_score=only_with_min_final_score,
            cf_agent=cf_agent,
            n_cf_agent_samples=n_cf_agent_samples,
            value_decay_alpha=value_decay_alpha,
            input_valid_power_idxs=input_valid_power_idxs,
            exclude_n_holds=exclude_n_holds,
        )
        for phase_idx in range(num_phases)
    ]

    stacked_encodings = DataFields.cat(phase_encodings)

    return game_id, stacked_encodings.to_storage_fmt_()
def encode_inputs_state_only(self, games: Sequence[pydipcc.Game]) -> DataFields:
    """Encode state-only inputs for a batch of games.

    Delegates the work to self.thread_pool.encode_inputs_state_only_multi
    and wraps whatever it returns in a DataFields.
    """
    raw_inputs = self.thread_pool.encode_inputs_state_only_multi(games)
    return DataFields(raw_inputs)