def cat_pad_inputs(xs: List[DataFields]) -> DataFields:
    """Merge a list of DataFields into one batch along dim 0.

    The ragged per-location fields ("x_possible_actions" and "x_loc_idxs") are
    merged with ``cat_pad_sequences`` and padded to ``MAX_SEQ_LEN``. Every other
    field is merged with a plain ``torch.cat``. The keys are taken from ``xs[0]``.
    """
    # Pad value used for each field that needs padding; other fields are not padded.
    pad_value_by_key = {"x_possible_actions": -1, "x_loc_idxs": EOS_IDX}

    batch = DataFields({key: [item[key] for item in xs] for key in xs[0].keys()})
    for key, pieces in batch.items():
        if key not in pad_value_by_key:
            batch[key] = torch.cat(pieces)
            continue
        batch[key] = cat_pad_sequences(
            pieces,
            pad_value=pad_value_by_key[key],
            pad_to_len=MAX_SEQ_LEN,
        )
    return batch
Example #2
0
def rollout_to_batch(
    rollout: ExploitRollout, reward_kwargs: Dict
) -> Tuple[RolloutBatch, ScoreDict]:
    """Turn one ExploitRollout into a RolloutBatch for training.

    Returns the batch together with the result of
    ``game_scoring.compute_game_scores`` for the rollout's power, converted to
    a dict via ``_asdict()``.
    """
    num_steps = len(rollout.actions)
    num_json_phases = len(rollout.game_json["phases"])
    # rollout.first_phase is non-zero when the rollout resumed an existing
    # game, so the JSON then holds that many phases before the first action.
    assert num_json_phases == num_steps + 1 + rollout.first_phase, (
        num_json_phases,
        num_steps,
    )

    power = rollout.power_id
    scores = game_scoring.compute_game_scores(power, rollout.game_json)._asdict()
    rewards = compute_reward(rollout, **reward_kwargs)

    # Only the last step of the rollout is flagged as done.
    done = torch.zeros([num_steps], dtype=torch.bool)
    done[-1] = True

    # Keep only the current power's slice of the per-power observation
    # fields; the rest of the observation is passed through as-is.
    observations = DataFields(rollout.observations)
    for field in ("x_loc_idxs", "x_possible_actions"):
        observations[field] = observations[field][:, power].clone()

    batch = RolloutBatch(
        power_ids=torch.full([num_steps], power, dtype=torch.long),
        rewards=rewards,
        observations=observations,
        actions=rollout.actions,
        logprobs=rollout.logprobs,
        done=done,
    )
    return batch, scores
Example #3
0
def _join_batches(batches: Sequence[RolloutBatch]) -> RolloutBatch:
    """Concatenate several RolloutBatch objects field by field along dim 0."""

    def _merge_field(name):
        column = [getattr(batch, name) for batch in batches]
        if name == "observations":
            # Observations are a DataFields dict, which provides its own concat.
            return DataFields.cat(column)
        return torch.cat(column, 0)

    return RolloutBatch(
        **{name: _merge_field(name) for name in RolloutBatch._fields}
    )
    def do_model_request(
            cls, x: DataFields, temperature: float, top_p: float, *,
            client: postman.Client
    ) -> Tuple[List[List[Tuple[str, ...]]], torch.Tensor, np.ndarray]:
        """Send a synchronous request to the model server.

        NOTE: ``x`` is modified in place; "temperature" and "top_p" entries
        (shape [B, 1]) are added to it before the request is sent.

        Arguments:
        - x: a DataFields dict of Tensors, where each tensor's dim=0 is the batch dim
        - temperature: model softmax temperature
        - top_p: probability mass to sample from
        - client: postman client whose ``evaluate`` method runs the model

        Returns a 3-tuple:
        - a list (len = batch size) of lists (len = len(POWERS)) of order-tuples
        - order_logprobs, passed through unchanged from the server reply
          (presumably a Tensor, like the other reply fields; TODO confirm)
        - estimated final scores as a numpy array whose dim 0 is the batch dim
          (converted via ``.numpy()`` when the server returns a Tensor)

        Raises:
        - whatever ``client.evaluate`` raises, re-raised after logging the
          shape/dtype of every input
        - AssertionError if the reply's batch size differs from the request's
        """
        B = x["x_board_state"].shape[0]
        x["temperature"] = torch.zeros(B, 1).fill_(temperature)
        x["top_p"] = torch.zeros(B, 1).fill_(top_p)
        try:
            order_idxs, order_logprobs, final_scores = client.evaluate(x)
        except Exception:
            # Log the inputs for debugging, then let the caller handle the error.
            logging.error("Caught server error with inputs {}".format([
                (k, v.shape, v.dtype) for k, v in x.items()
            ]))
            raise
        assert x["x_board_state"].shape[0] == final_scores.shape[0], (
            x["x_board_state"].shape[0],
            final_scores.shape[0],
        )
        if hasattr(final_scores, "numpy"):
            final_scores = final_scores.numpy()

        decoded_orders = [
            [
                tuple(decode_order_idxs(order_idxs[b, p, :]))
                for p in range(len(POWERS))
            ]
            for b in range(order_idxs.shape[0])
        ]
        return decoded_orders, order_logprobs, final_scores
    def from_merge(cls,
                   datasets: Sequence["Dataset"]) -> torch.utils.data.Dataset:
        """Build one preprocessed Dataset by concatenating ``datasets``.

        Every input must already be preprocessed and use
        ``n_cf_agent_samples == 1``. Otherwise NotImplementedError is raised.

        Game indices and x indices are shifted by the combined game / phase
        counts of the datasets that come before them, so they still point at
        the right rows of the concatenated data. Phase and power indices are
        counted within a game, so they are copied unchanged.
        """
        if any(d.n_cf_agent_samples != 1 or not d._preprocessed
               for d in datasets):
            raise NotImplementedError()

        merged = Dataset(
            game_ids=[game_id for d in datasets for game_id in d.game_ids],
            data_dir=None,
            game_metadata=None,
            only_with_min_final_score=None,
            num_dataloader_workers=None,
            value_decay_alpha=None,
        )

        def _start_offsets(count_attr):
            # Exclusive prefix sum: dataset i is offset by the total
            # `count_attr` of datasets 0..i-1.
            return torch.from_numpy(
                np.cumsum([0] + [getattr(d, count_attr) for d in datasets[:-1]]))

        def _shifted_cat(idx_attr, offsets):
            return torch.cat(
                [getattr(d, idx_attr) + off for d, off in zip(datasets, offsets)])

        def _plain_cat(idx_attr):
            return torch.cat([getattr(d, idx_attr) for d in datasets])

        merged.game_idxs = _shifted_cat("game_idxs", _start_offsets("num_games"))
        merged.phase_idxs = _plain_cat("phase_idxs")
        merged.power_idxs = _plain_cat("power_idxs")
        merged.x_idxs = _shifted_cat("x_idxs", _start_offsets("num_phases"))
        merged.encoded_games = DataFields.cat(
            [d.encoded_games for d in datasets])
        for count_attr in ("num_games", "num_phases", "num_elements"):
            setattr(merged, count_attr,
                    sum(getattr(d, count_attr) for d in datasets))
        merged._preprocessed = True

        return merged
    def preprocess(self):
        """Encode every game in ``self.game_ids`` and build flat index tensors.

        After this runs:
        - ``self.game_ids`` holds only the games that loaded successfully and
          whose first phase has at least one valid power. It stays aligned
          with ``self.encoded_games``.
        - ``self.encoded_games`` is ``DataFields.cat`` of those games' encodings.
        - ``self.game_idxs`` / ``phase_idxs`` / ``power_idxs`` / ``x_idxs`` are
          parallel long tensors with one entry per (game, phase, valid power).
          ``x_idxs`` is a running phase counter across all kept games.
        - ``num_games`` / ``num_phases`` / ``num_elements`` are set, the dataset
          is validated, and ``_preprocessed`` is set to True.
        """
        assert not self.debug_only_opening_phase, "FIXME"

        logging.info(
            f"Building dataset from {len(self.game_ids)} games, "
            f"only_with_min_final_score={self.only_with_min_final_score} "
            f"value_decay_alpha={self.value_decay_alpha} cf_agent={self.cf_agent}"
        )

        torch.set_num_threads(1)
        encoded_game_tuples = self.mp_encode_games()

        # Remove "empty" games (e.g. json didn't exist). Ids and encodings are
        # filtered together as pairs so the two lists cannot drift apart.
        loaded = [(g_id, g) for (g_id, g) in encoded_game_tuples if g is not None]

        logging.info(
            f"Found data for {len(loaded)} / {len(self.game_ids)} games"
        )

        # NOTE(review): only phase 0's valid_power_idxs is inspected here.
        kept = [(g_id, g) for (g_id, g) in loaded if g["valid_power_idxs"][0].any()]
        logging.info(
            f"{len(kept)} games had data for at least one power")

        # Update game_ids so they stay index-aligned with encoded_games.
        self.game_ids = [g_id for (g_id, _) in kept]
        encoded_games = [g for (_, g) in kept]

        game_idxs, phase_idxs, power_idxs, x_idxs = [], [], [], []
        x_idx = 0
        for game_idx, encoded_game in enumerate(encoded_games):
            for phase_idx, valid_power_idxs in enumerate(
                    encoded_game["valid_power_idxs"]):
                assert valid_power_idxs.nelement() == len(POWERS), (
                    encoded_game["valid_power_idxs"].shape,
                    valid_power_idxs.shape,
                )
                # .tolist() yields plain ints rather than 0-d tensors.
                for power_idx in valid_power_idxs.nonzero()[:, 0].tolist():
                    game_idxs.append(game_idx)
                    phase_idxs.append(phase_idx)
                    power_idxs.append(power_idx)
                    x_idxs.append(x_idx)
                x_idx += 1

        self.game_idxs = torch.tensor(game_idxs, dtype=torch.long)
        self.phase_idxs = torch.tensor(phase_idxs, dtype=torch.long)
        self.power_idxs = torch.tensor(power_idxs, dtype=torch.long)
        self.x_idxs = torch.tensor(x_idxs, dtype=torch.long)

        # now collate the data into giant tensors!
        self.encoded_games = DataFields.cat(encoded_games)

        self.num_games = len(encoded_games)
        self.num_phases = len(
            self.encoded_games["x_board_state"]) if self.encoded_games else 0
        self.num_elements = len(self.x_idxs)

        self.validate_dataset()

        self._preprocessed = True
def encode_game(
    game_id: Union[int, str],
    data_dir: str,
    only_with_min_final_score=7,
    *,
    cf_agent=None,
    n_cf_agent_samples=1,
    input_valid_power_idxs,
    value_decay_alpha,
    game_metadata,
    exclude_n_holds,
):
    """Load one game's JSON file and encode every phase of it.

    Arguments:
    - game_id: either a path to the game JSON (str), or an integer id, in
      which case ``<data_dir>/game_<game_id>.json`` is loaded
    - data_dir: directory containing ``game_<id>.json`` files (only used for
      integer ids)
    - only_with_min_final_score: if specified, only encode for powers who
      finish the game with some # of supply centers (i.e. only learn from
      winners). MILA uses 7.
    - input_valid_power_idxs: bool tensor, true if power should a priori be included in
      the dataset based on e.g. player rating)
    - cf_agent, n_cf_agent_samples, value_decay_alpha, exclude_n_holds:
      forwarded unchanged to ``encode_phase``
    - game_metadata: not used in this function's body

    Return: (game_id, DataFields dict of tensors) on success, or (None, None)
    if the file is missing, is not valid UTF-8, or is not valid JSON.
    L is game length (number of phases), N is n_cf_agent_samples
    - board_state: shape=(L, 81, 35)
    - prev_state: shape=(L, 81, 35)
    - prev_orders: shape=(L, 2, 100), dtype=long
    - power: shape=(L, 7, 7)
    - season: shape=(L, 3)
    - in_adj_phase: shape=(L, 1)
    - build_numbers: shape=(L, 7)
    - final_scores: shape=(L, 7)
    - possible_actions: TensorList shape=(L x 7, 17 x 469)
    - loc_idxs: shape=(L, 7, 81), int8
    - actions: shape=(L, 7, N, 17) int order idxs, N=n_cf_agent_samples
    - valid_power_idxs: shape=(L, 7) bool mask of valid powers at each phase
    """
    torch.set_num_threads(1)

    if isinstance(game_id, str):
        # Hacky fix to handle game_ids that are paths.
        game_path = game_id
    else:
        game_path = os.path.join(f"{data_dir}", f"game_{game_id}.json")

    try:
        # JSON is UTF-8 by spec; don't depend on the process locale.
        with open(game_path, encoding="utf-8") as f:
            game = Game.from_json(f.read())
    except (FileNotFoundError, UnicodeDecodeError,
            json.decoder.JSONDecodeError) as e:
        # Best effort: skip unreadable games; callers filter out (None, None).
        logging.warning(f"Error while loading game at {game_path}: {e}")
        return None, None

    # Build the encoder only once we know the game actually loaded.
    encoder = FeatureEncoder()

    num_phases = len(game.get_phase_history())
    logging.info(f"Encoding {game.game_id} with {num_phases} phases")

    phase_encodings = [
        encode_phase(
            encoder,
            game,
            game_id,
            phase_idx,
            only_with_min_final_score=only_with_min_final_score,
            cf_agent=cf_agent,
            n_cf_agent_samples=n_cf_agent_samples,
            value_decay_alpha=value_decay_alpha,
            input_valid_power_idxs=input_valid_power_idxs,
            exclude_n_holds=exclude_n_holds,
        ) for phase_idx in range(num_phases)
    ]

    stacked_encodings = DataFields.cat(phase_encodings)

    return game_id, stacked_encodings.to_storage_fmt_()
Example #8
0
 def encode_inputs_state_only(self,
                              games: Sequence[pydipcc.Game]) -> DataFields:
     """Encode ``games`` via the thread pool and wrap the result in DataFields.

     Delegates to ``self.thread_pool.encode_inputs_state_only_multi``.
     """
     encoded = self.thread_pool.encode_inputs_state_only_multi(games)
     return DataFields(encoded)