Example #1
    def __init__(self, t_prof, seat_id, chief_handle):
        self.ddqn_args = t_prof.module_args["ddqn"]
        self.avg_args = t_prof.module_args["avg"]
        super().__init__(t_prof=t_prof, chief_handle=chief_handle)

        self.seat_id = seat_id
        self.global_iter_id = 0

        self.eps = self.ddqn_args.eps_start

        self.q_net = DuelingQNet(q_args=self.ddqn_args.q_args,
                                 env_bldr=self._env_bldr,
                                 device=self._device)
        self.avg_net = AvrgStrategyNet(
            avrg_net_args=self.avg_args.avg_net_args,
            env_bldr=self._env_bldr,
            device=self._device)

        self.br_optim = rl_util.str_to_optim_cls(self.ddqn_args.optim_str)(
            self.q_net.parameters(), lr=self.ddqn_args.lr)
        self.avg_optim = rl_util.str_to_optim_cls(self.avg_args.optim_str)(
            self.avg_net.parameters(), lr=self.avg_args.lr)

        self.eps_exp = self._ray.remote(
            self._chief_handle.create_experiment,
            t_prof.name + ": epsilon Plyr" + str(seat_id))
        self._log_eps()
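
The ParameterServer builds its optimizers through `rl_util.str_to_optim_cls`, whose body is not shown in these examples. A minimal sketch of such a string-to-optimizer lookup, assuming plain torch.optim names (the real helper may support different keys):

import torch

# Hypothetical lookup table; only the call pattern
# str_to_optim_cls(optim_str)(net.parameters(), lr=...) is taken from the
# example above. The supported keys are an assumption.
_OPTIM_TABLE = {
    "adam": torch.optim.Adam,
    "sgd": torch.optim.SGD,
    "rmsprop": torch.optim.RMSprop,
}


def str_to_optim_cls_sketch(optim_str):
    try:
        return _OPTIM_TABLE[optim_str.lower()]
    except KeyError:
        raise ValueError("Unknown optimizer string: " + optim_str)

# Usage mirroring the ParameterServer above:
#   br_optim = str_to_optim_cls_sketch("adam")(q_net.parameters(), lr=1e-3)
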
Example #2
    def __init__(self,
                 env_bldr,
                 ddqn_args,
                 owner,
                 ):
        super().__init__(
            net=DuelingQNet(env_bldr=env_bldr, q_args=ddqn_args.q_args, device=ddqn_args.device_training),
            env_bldr=env_bldr,
            args=ddqn_args,
            owner=owner,
            device=ddqn_args.device_training,
        )

        self._eps = None

        self._target_net = DuelingQNet(env_bldr=env_bldr, q_args=ddqn_args.q_args, device=ddqn_args.device_training)
        self._target_net.eval()
        self.update_target_net()

        self._batch_arranged = torch.arange(ddqn_args.batch_size, dtype=torch.long, device=self.device)
        self._minus_e20 = torch.full((ddqn_args.batch_size, self._env_bldr.N_ACTIONS,),
                                     fill_value=-10e20,
                                     device=self.device,
                                     dtype=torch.float32,
                                     requires_grad=False)

        self._n_actions_arranged = np.arange(self._env_bldr.N_ACTIONS, dtype=np.int32).tolist()
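
The preallocated `_minus_e20` buffer above is used to push the Q-values of illegal actions far below any legal value before a max/argmax (see `_mini_batch_loop` in Example #6). A self-contained sketch with made-up shapes:

import torch

batch_size, n_actions = 4, 3
q = torch.randn(batch_size, n_actions)          # toy Q-values
legal_mask = torch.tensor([[1, 1, 0],
                           [1, 0, 1],
                           [0, 1, 1],
                           [1, 1, 1]], dtype=torch.bool)
minus_e20 = torch.full((batch_size, n_actions), -10e20)

# Illegal actions get a huge negative value, so max/argmax can only ever
# pick legal actions.
masked_q = torch.where(legal_mask, q, minus_e20)
best_legal_a = masked_q.argmax(dim=-1)
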
Example #3
    def load_net_state_dict(self, state_dict):
        if state_dict is None:
            return  # this should only happen at iteration 0; in that case this class returns random actions.
        else:
            self._adv_net = DuelingQNet(q_args=self._t_prof.module_args["adv_training"].adv_net_args,
                                        env_bldr=self._env_bldr, device=self._device)
            self._adv_net.load_state_dict(state_dict)
            self._adv_net.to(self._device)

        self._adv_net.eval()
        for param in self._adv_net.parameters():
            param.requires_grad = False
Example #4
    def __init__(self, env_bldr, adv_training_args, owner, device):
        super().__init__(
            net=DuelingQNet(env_bldr=env_bldr, q_args=adv_training_args.adv_net_args, device=device),
            env_bldr=env_bldr,
            args=adv_training_args,
            owner=owner,
            device=device
        )
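
Several of these examples subclass `_NetWrapperBase`, whose definition is not included here. Purely to make the `super().__init__(...)` calls concrete, here is a hypothetical minimal stand-in; only the attributes actually referenced in the snippets (`_net`, `_env_bldr`, `_args`, `owner`, `device`) are taken from them, everything else is an assumption:

class _NetWrapperBaseSketch:
    # Hypothetical stand-in for _NetWrapperBase, not the real base class.
    def __init__(self, net, env_bldr, args, owner, device):
        self._net = net
        self._env_bldr = env_bldr
        self._args = args
        self.owner = owner
        self.device = device

    def eval(self):
        # Example #6 calls self.eval() before inference.
        self._net.eval()

    def train(self):
        self._net.train()
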
Example #5
    def __init__(self, env_bldr, baseline_args):
        super().__init__(
            net=DuelingQNet(env_bldr=env_bldr, q_args=baseline_args.q_net_args, device=baseline_args.device_training),
            owner=None,
            env_bldr=env_bldr,
            args=baseline_args,
            device=baseline_args.device_training,
        )

        self._batch_arranged = torch.arange(self._args.batch_size, dtype=torch.long, device=self.device)
        self._minus_e20 = torch.full((self._args.batch_size, self._env_bldr.N_ACTIONS,),
                                     fill_value=-10e20,
                                     device=self.device,
                                     dtype=torch.float32,
                                     requires_grad=False)
Example #6
class DDQN(_NetWrapperBase):

    def __init__(self,
                 env_bldr,
                 ddqn_args,
                 owner,
                 ):
        super().__init__(
            net=DuelingQNet(env_bldr=env_bldr, q_args=ddqn_args.q_args, device=ddqn_args.device_training),
            env_bldr=env_bldr,
            args=ddqn_args,
            owner=owner,
            device=ddqn_args.device_training,
        )

        self._eps = None

        self._target_net = DuelingQNet(env_bldr=env_bldr, q_args=ddqn_args.q_args, device=ddqn_args.device_training)
        self._target_net.eval()
        self.update_target_net()

        self._batch_arranged = torch.arange(ddqn_args.batch_size, dtype=torch.long, device=self.device)
        self._minus_e20 = torch.full((ddqn_args.batch_size, self._env_bldr.N_ACTIONS,),
                                     fill_value=-10e20,
                                     device=self.device,
                                     dtype=torch.float32,
                                     requires_grad=False)

        self._n_actions_arranged = np.arange(self._env_bldr.N_ACTIONS, dtype=np.int32).tolist()

    @property
    def eps(self):
        return self._eps

    @eps.setter
    def eps(self, value):
        self._eps = value

    def select_br_a(self, pub_obses, range_idxs, legal_actions_lists, explore=False):
        if explore and (np.random.random() < self._eps):
            return np.array(
                [legal_actions[np.random.randint(len(legal_actions))] for legal_actions in legal_actions_lists]
            )

        with torch.no_grad():
            self.eval()
            range_idxs = torch.tensor(range_idxs, dtype=torch.long, device=self.device)
            q = self._net(pub_obses=pub_obses, range_idxs=range_idxs,
                          legal_action_masks=rl_util.batch_get_legal_action_mask_torch(
                              n_actions=self._env_bldr.N_ACTIONS,
                              legal_actions_lists=legal_actions_lists,
                              device=self.device,
                              dtype=torch.float32)).cpu().numpy()
            for b in range(q.shape[0]):
                illegal_actions = [i for i in self._n_actions_arranged if i not in legal_actions_lists[b]]
                if len(illegal_actions) > 0:
                    illegal_actions = np.array(illegal_actions)
                    q[b, illegal_actions] = -1e20

            return np.argmax(q, axis=1)

    def update_target_net(self):
        self._target_net.load_state_dict(self._net.state_dict())
        self._target_net.eval()

    def _mini_batch_loop(self, buffer, grad_mngr):
        batch_pub_obs_t, \
        batch_a_t, \
        batch_range_idx, \
        batch_legal_action_mask_t, \
        batch_r_t, \
        batch_pub_obs_tp1, \
        batch_legal_action_mask_tp1, \
        batch_done = \
            buffer.sample(device=self.device, batch_size=self._args.batch_size)

        # [batch_size, n_actions]
        q1_t = self._net(pub_obses=batch_pub_obs_t, range_idxs=batch_range_idx,
                         legal_action_masks=batch_legal_action_mask_t.to(torch.float32))
        q1_tp1 = self._net(pub_obses=batch_pub_obs_tp1, range_idxs=batch_range_idx,
                           legal_action_masks=batch_legal_action_mask_tp1.to(torch.float32)).detach()
        q2_tp1 = self._target_net(pub_obses=batch_pub_obs_tp1, range_idxs=batch_range_idx,
                                  legal_action_masks=batch_legal_action_mask_tp1.to(torch.float32)).detach()

        # ______________________________________________ TD Learning _______________________________________________
        # [batch_size]
        q1_t_of_a_selected = q1_t[self._batch_arranged, batch_a_t]

        # only consider allowed actions for tp1
        q1_tp1 = torch.where(batch_legal_action_mask_tp1.bool(),
                             q1_tp1,
                             self._minus_e20)

        # [batch_size]
        _, best_a_tp1 = q1_tp1.max(dim=-1, keepdim=False)
        q2_best_a_tp1 = q2_tp1[self._batch_arranged, best_a_tp1]

        q2_best_a_tp1 = q2_best_a_tp1 * (1.0 - batch_done)
        target = batch_r_t + q2_best_a_tp1

        grad_mngr.backprop(pred=q1_t_of_a_selected, target=target)

    def state_dict(self):
        return {
            "q_net": self._net.state_dict(),
            "target_net": self._target_net.state_dict(),
            "eps": self._eps,
            "owner": self.owner,
            "args": self._args,
        }

    def load_state_dict(self, state):
        assert self.owner == state["owner"]
        # Not loading args by design

        self._net.load_state_dict(state["q_net"])
        self._target_net.load_state_dict(state["target_net"])
        self._eps = state["eps"]

    @staticmethod
    def from_state_dict(state_dict, env_bldr):
        ddqn = DDQN(owner=state_dict["owner"],
                    ddqn_args=state_dict["args"],
                    env_bldr=env_bldr)
        ddqn.load_state_dict(state_dict)
        ddqn.update_target_net()
        return ddqn

    @staticmethod
    def inference_version_from_state_dict(state_dict, env_bldr):
        ddqn = DDQN.from_state_dict(state_dict=state_dict, env_bldr=env_bldr)
        ddqn.buf = None
        ddqn.eps = None
        return ddqn
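
The TD target built in `_mini_batch_loop` above is the standard double-Q target: the online net selects the next action, the target net evaluates it, and terminal transitions drop the bootstrap term. A toy recomputation with made-up values:

import torch

batch_arranged = torch.arange(2)

q1_tp1 = torch.tensor([[0.2, 0.9, 0.1],
                       [0.5, 0.4, 0.8]])   # online net at t+1
q2_tp1 = torch.tensor([[0.3, 0.7, 0.2],
                       [0.6, 0.1, 0.4]])   # target net at t+1
r_t = torch.tensor([1.0, 0.0])
done = torch.tensor([0.0, 1.0])

best_a_tp1 = q1_tp1.argmax(dim=-1)                  # selection: online net
q2_best_a_tp1 = q2_tp1[batch_arranged, best_a_tp1]  # evaluation: target net
target = r_t + (1.0 - done) * q2_best_a_tp1         # no bootstrap when done
# -> tensor([1.7000, 0.0000])
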
Example #7
    def _get_new_net(self):
        return DuelingQNet(q_args=self._args.ddqn_args.q_args,
                           env_bldr=self._eval_env_bldr,
                           device=self._device)
Example #8
class IterationStrategy:
    def __init__(self, t_prof, owner, env_bldr, device, cfr_iter):
        self._t_prof = t_prof
        self._owner = owner
        self._env_bldr = env_bldr
        self._device = device
        self._cfr_iter = cfr_iter

        self._adv_net = None
        self._all_range_idxs = torch.arange(self._env_bldr.rules.RANGE_SIZE,
                                            device=self._device,
                                            dtype=torch.long)

    @property
    def owner(self):
        return self._owner

    @property
    def cfr_iteration(self):
        return self._cfr_iter

    @property
    def device(self):
        return self._device

    def reset(self):
        self._adv_net = None

    def get_action(self, pub_obses, range_idxs, legal_actions_lists):
        a_probs = self.get_a_probs(pub_obses=pub_obses,
                                   range_idxs=range_idxs,
                                   legal_actions_lists=legal_actions_lists)

        return torch.multinomial(torch.from_numpy(a_probs),
                                 num_samples=1).cpu().numpy()

    def get_a_probs2(self,
                     pub_obses,
                     range_idxs,
                     legal_action_masks,
                     to_np=True):
        """
        Args:
            pub_obses (list):               batch (list) of np arrays of shape [np.arr([history_len, n_features]), ...)
            range_idxs (list):              batch (list) of range_idxs (one for each pub_obs) [2, 421, 58, 912, ...]
            legal_action_masks (Torch.tensor)
        """

        with torch.no_grad():
            bs = len(range_idxs)

            if self._adv_net is None:  # at iteration 0
                uniform_even_legal = legal_action_masks / (
                    legal_action_masks.sum(-1).unsqueeze(-1).expand_as(
                        legal_action_masks))
                if to_np:
                    return uniform_even_legal.cpu().numpy()
                return uniform_even_legal
            else:
                range_idxs = torch.tensor(range_idxs,
                                          dtype=torch.long,
                                          device=self._device)
                pub_obses = torch.tensor(pub_obses,
                                         dtype=torch.float32,
                                         device=self._device)

                advantages = self._adv_net(
                    pub_obses=pub_obses,
                    range_idxs=range_idxs,
                    legal_action_masks=legal_action_masks)

                # """"""""""""""""""""
                relu_advantages = F.relu(
                    advantages, inplace=False
                )  # Because only the sum of *positive* regrets matters in CFR
                sum_pos_adv_expanded = relu_advantages.sum(1).unsqueeze(
                    -1).expand_as(relu_advantages)

                # """"""""""""""""""""
                # In case all negative
                # """"""""""""""""""""
                best_legal_deterministic = torch.zeros((
                    bs,
                    self._env_bldr.N_ACTIONS,
                ),
                                                       dtype=torch.float32,
                                                       device=self._device)
                bests = torch.argmax(torch.where(
                    legal_action_masks.bool(), advantages,
                    torch.full_like(advantages, fill_value=-10e20)),
                                     dim=1)
                _batch_arranged = torch.arange(bs,
                                               device=self._device,
                                               dtype=torch.long)
                best_legal_deterministic[_batch_arranged, bests] = 1

                # """"""""""""""""""""
                # Strat
                # """"""""""""""""""""
                strategy = torch.where(sum_pos_adv_expanded > 0,
                                       relu_advantages / sum_pos_adv_expanded,
                                       best_legal_deterministic)

                if to_np:
                    strategy = strategy.cpu().numpy()
                return strategy

    def get_a_probs(self,
                    pub_obses,
                    range_idxs,
                    legal_actions_lists,
                    to_np=True):
        """
        Args:
            pub_obses (list):               batch (list) of np arrays of shape [np.arr([history_len, n_features]), ...)
            range_idxs (list):              batch (list) of range_idxs (one for each pub_obs) [2, 421, 58, 912, ...]
            legal_actions_lists (list):     batch (list) of lists of integers that represent legal actions
        """

        with torch.no_grad():
            masks = rl_util.batch_get_legal_action_mask_torch(
                n_actions=self._env_bldr.N_ACTIONS,
                legal_actions_lists=legal_actions_lists,
                device=self._device,
                dtype=torch.float32)
            return self.get_a_probs2(pub_obses=pub_obses,
                                     range_idxs=range_idxs,
                                     legal_action_masks=masks,
                                     to_np=to_np)

    def get_a_probs_for_each_hand(self, pub_obs, legal_actions_list):
        """
        Args:
            pub_obs (np.array(shape=(seq_len, n_features,)))
            legal_actions_list (list):      list of ints representing legal actions
        """

        if self._t_prof.DEBUGGING:
            assert isinstance(pub_obs, np.ndarray)
            assert len(
                pub_obs.shape) == 2, "all hands have the same public obs"
            assert isinstance(
                legal_actions_list[0], int
            ), "all hands do the same actions. no need to batch, just parse int"

        return self._get_a_probs_of_hands(
            pub_obs=pub_obs,
            legal_actions_list=legal_actions_list,
            range_idxs_tensor=self._all_range_idxs)

    def get_a_probs_for_each_hand_in_list(self, pub_obs, range_idxs,
                                          legal_actions_list):
        """
        Args:
            pub_obs (np.array(shape=(seq_len, n_features,)))
            range_idxs (np.ndarray):        list of range_idxs to evaluate in the public state "pub_obs"
            legal_actions_list (list):      list of ints representing legal actions
        """

        if self._t_prof.DEBUGGING:
            assert isinstance(pub_obs, np.ndarray)
            assert isinstance(range_idxs, np.ndarray)
            assert len(
                pub_obs.shape) == 2, "all hands have the same public obs"
            assert isinstance(
                legal_actions_list[0],
                int), "all hands can do the same actions. no need to batch"

        return self._get_a_probs_of_hands(
            pub_obs=pub_obs,
            legal_actions_list=legal_actions_list,
            range_idxs_tensor=torch.from_numpy(range_idxs).to(
                dtype=torch.long, device=self._device))

    def _get_a_probs_of_hands(self, pub_obs, range_idxs_tensor,
                              legal_actions_list):
        with torch.no_grad():
            n_hands = range_idxs_tensor.size(0)

            if self._adv_net is None:  # at iteration 0
                uniform_even_legal = torch.zeros((self._env_bldr.N_ACTIONS, ),
                                                 dtype=torch.float32,
                                                 device=self._device)
                uniform_even_legal[legal_actions_list] = 1.0 / len(
                    legal_actions_list)  # always >0
                uniform_even_legal = uniform_even_legal.unsqueeze(0).expand(
                    n_hands, self._env_bldr.N_ACTIONS)
                return uniform_even_legal.cpu().numpy()

            else:
                legal_action_masks = rl_util.get_legal_action_mask_torch(
                    n_actions=self._env_bldr.N_ACTIONS,
                    legal_actions_list=legal_actions_list,
                    device=self._device,
                    dtype=torch.float32)
                legal_action_masks = legal_action_masks.unsqueeze(0).expand(
                    n_hands, -1)

                advantages = self._adv_net(
                    pub_obses=[pub_obs] * n_hands,
                    range_idxs=range_idxs_tensor,
                    legal_action_masks=legal_action_masks)

                # """"""""""""""""""""
                relu_advantages = F.relu(
                    advantages, inplace=False
                )  # Because only the sum of *positive* regrets matters in CFR
                sum_pos_adv_expanded = relu_advantages.sum(1).unsqueeze(
                    -1).expand_as(relu_advantages)

                # """"""""""""""""""""
                # In case all negative
                # """"""""""""""""""""
                best_legal_deterministic = torch.zeros((
                    n_hands,
                    self._env_bldr.N_ACTIONS,
                ),
                                                       dtype=torch.float32,
                                                       device=self._device)
                bests = torch.argmax(torch.where(
                    legal_action_masks.bool(), advantages,
                    torch.full_like(advantages, fill_value=-10e20)),
                                     dim=1)

                _batch_arranged = torch.arange(n_hands,
                                               device=self._device,
                                               dtype=torch.long)
                best_legal_deterministic[_batch_arranged, bests] = 1

                # """"""""""""""""""""
                # Strategy
                # """"""""""""""""""""
                strategy = torch.where(
                    sum_pos_adv_expanded > 0,
                    relu_advantages / sum_pos_adv_expanded,
                    best_legal_deterministic,
                )

                return strategy.cpu().numpy()

    def state_dict(self):
        return {
            "owner": self._owner,
            "net": self.net_state_dict(),
            "iter": self._cfr_iter,
        }

    @staticmethod
    def build_from_state_dict(t_prof, env_bldr, device, state):
        s = IterationStrategy(t_prof=t_prof,
                              env_bldr=env_bldr,
                              device=device,
                              owner=state["owner"],
                              cfr_iter=state["iter"])
        s.load_state_dict(state=state)  # loads net state
        return s

    def load_state_dict(self, state):
        assert self._owner == state["owner"]
        self.load_net_state_dict(state["net"])
        self._cfr_iter = state["iter"]

    def net_state_dict(self):
        """ This just wraps the net.state_dict() with the option of returning None if net is None """
        if self._adv_net is None:
            return None
        return self._adv_net.state_dict()

    def load_net_state_dict(self, state_dict):
        if state_dict is None:
            return  # this should only happen at iteration 0; in that case this class returns random actions.
        else:
            self._adv_net = DuelingQNet(
                q_args=self._t_prof.module_args["adv_training"].adv_net_args,
                env_bldr=self._env_bldr,
                device=self._device)
            if self._t_prof.module_args[
                    'adv_training'].init_adv_model == "last":
                # load net inner parameters only if we use it
                self._adv_net.load_state_dict(state_dict)
            self._adv_net.to(self._device)

        self._adv_net.eval()
        for param in self._adv_net.parameters():
            param.requires_grad = False

    def get_copy(self, device=None):
        _device = self._device if device is None else device
        return IterationStrategy.build_from_state_dict(t_prof=self._t_prof,
                                                       env_bldr=self._env_bldr,
                                                       device=_device,
                                                       state=self.state_dict())
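
Both `get_a_probs2` and `_get_a_probs_of_hands` implement regret matching over the predicted advantages: positive advantages are normalized into a distribution, and when no advantage is positive the strategy collapses onto the single best legal action. A standalone sketch with toy numbers:

import torch
import torch.nn.functional as F

advantages = torch.tensor([[ 1.0,  3.0, -2.0],
                           [-1.0, -0.5, -3.0]])
legal_mask = torch.ones_like(advantages, dtype=torch.bool)

relu_adv = F.relu(advantages)
sum_pos = relu_adv.sum(-1, keepdim=True).expand_as(relu_adv)

# Fallback for rows where every advantage is negative: put all mass on the
# best legal action instead.
best = torch.where(legal_mask, advantages,
                   torch.full_like(advantages, -10e20)).argmax(dim=-1)
deterministic = torch.zeros_like(advantages)
deterministic[torch.arange(advantages.shape[0]), best] = 1.0

strategy = torch.where(sum_pos > 0, relu_adv / sum_pos, deterministic)
# row 0 -> [0.25, 0.75, 0.00], row 1 -> [0.00, 1.00, 0.00]
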
Example #9
    def _get_new_baseline_net(self):
        return DuelingQNet(q_args=self._baseline_args.q_net_args,
                           env_bldr=self._env_bldr,
                           device=self._device)
Example #10
    def _get_new_adv_net(self):
        return DuelingQNet(q_args=self._adv_args.adv_net_args,
                           env_bldr=self._env_bldr,
                           device=self._device)
Example #11
class ParameterServer(_ParameterServerBase):
    def __init__(self, t_prof, seat_id, chief_handle):
        self.ddqn_args = t_prof.module_args["ddqn"]
        self.avg_args = t_prof.module_args["avg"]
        super().__init__(t_prof=t_prof, chief_handle=chief_handle)

        self.seat_id = seat_id
        self.global_iter_id = 0

        self.eps = self.ddqn_args.eps_start

        self.q_net = DuelingQNet(q_args=self.ddqn_args.q_args,
                                 env_bldr=self._env_bldr,
                                 device=self._device)
        self.avg_net = AvrgStrategyNet(
            avrg_net_args=self.avg_args.avg_net_args,
            env_bldr=self._env_bldr,
            device=self._device)

        self.br_optim = rl_util.str_to_optim_cls(self.ddqn_args.optim_str)(
            self.q_net.parameters(), lr=self.ddqn_args.lr)
        self.avg_optim = rl_util.str_to_optim_cls(self.avg_args.optim_str)(
            self.avg_net.parameters(), lr=self.avg_args.lr)

        self.eps_exp = self._ray.remote(
            self._chief_handle.create_experiment,
            t_prof.name + ": epsilon Plyr" + str(seat_id))
        self._log_eps()

    # ______________________________________________ API to pull from PS _______________________________________________
    def get_avg_weights(self):
        self.avg_net.zero_grad()
        return self._ray.state_dict_to_numpy(self.avg_net.state_dict())

    def get_q1_weights(self):
        self.q_net.zero_grad()
        return self._ray.state_dict_to_numpy(self.q_net.state_dict())

    def get_eps(self):
        return self.eps

    def _log_eps(self):
        self._ray.remote(self._chief_handle.add_scalar, self.eps_exp,
                         "Epsilon", self.global_iter_id, self.eps)

    # ____________________________________________ API to make PS compute ______________________________________________
    def apply_grads_br(self, list_grads):
        self._apply_grads(list_of_grads=list_grads,
                          optimizer=self.br_optim,
                          net=self.q_net,
                          grad_norm_clip=self.ddqn_args.grad_norm_clipping)

    def apply_grads_avg(self, list_grads):
        self._apply_grads(list_of_grads=list_grads,
                          optimizer=self.avg_optim,
                          net=self.avg_net,
                          grad_norm_clip=self.avg_args.grad_norm_clipping)

    def increment(self):
        self.global_iter_id += 1

        self.eps = rl_util.polynomial_decay(
            base=self.ddqn_args.eps_start,
            const=self.ddqn_args.eps_const,
            exponent=self.ddqn_args.eps_exponent,
            minimum=self.ddqn_args.eps_min,
            counter=self.global_iter_id)

        self._log_eps()
        return self.seat_id

    # ______________________________________________ API for checkpointing _____________________________________________
    def checkpoint(self, curr_step):
        state = {
            "seat_id": self.seat_id,
            "global_iter_id": self.global_iter_id,
            "eps": self.eps,
            "q_net": self.q_net.state_dict(),
            "avg_net": self.avg_net.state_dict(),
            "br_optim": self.br_optim.state_dict(),
            "avg_optim": self.avg_optim.state_dict(),
        }

        with open(
                self._get_checkpoint_file_path(
                    name=self._t_prof.name,
                    step=curr_step,
                    cls=self.__class__,
                    worker_id="P" + str(self.seat_id)), "wb") as pkl_file:
            pickle.dump(obj=state,
                        file=pkl_file,
                        protocol=pickle.HIGHEST_PROTOCOL)

    def load_checkpoint(self, name_to_load, step):
        with open(
                self._get_checkpoint_file_path(
                    name=name_to_load,
                    step=step,
                    cls=self.__class__,
                    worker_id="P" + str(self.seat_id)), "rb") as pkl_file:
            state = pickle.load(pkl_file)

            assert self.seat_id == state["seat_id"]

            self.eps = state["eps"]
            self.global_iter_id = state["global_iter_id"]
            self.q_net.load_state_dict(state["q_net"])
            self.avg_net.load_state_dict(state["avg_net"])
            self.br_optim.load_state_dict(state["br_optim"])
            self.avg_optim.load_state_dict(state["avg_optim"])
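
`increment()` anneals epsilon through `rl_util.polynomial_decay`, whose body is not shown in these snippets. One plausible implementation consistent with the call signature (base, const, exponent, minimum, counter); the actual formula in rl_util may differ:

def polynomial_decay_sketch(base, const, exponent, minimum, counter):
    # Assumed formula: shrink the starting value polynomially with the
    # iteration counter and clip it at the floor. Only the signature is
    # taken from increment() above.
    return max(minimum, base * (const / (const + counter)) ** exponent)

# e.g. polynomial_decay_sketch(base=0.3, const=300, exponent=0.7,
#      minimum=0.02, counter=1000) drifts from 0.3 toward 0.02 as the
#      counter grows.
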