def load_net_state_dict(self, state_dict):
    if state_dict is None:
        # Should only happen at iteration 0; in that case this class returns random actions.
        return
    else:
        self._adv_net = DuelingQNet(q_args=self._t_prof.module_args["adv_training"].adv_net_args,
                                    env_bldr=self._env_bldr,
                                    device=self._device)
        self._adv_net.load_state_dict(state_dict)
        self._adv_net.to(self._device)

        self._adv_net.eval()
        for param in self._adv_net.parameters():
            param.requires_grad = False
def __init__(self, env_bldr, adv_training_args, owner, device):
    super().__init__(
        net=DuelingQNet(env_bldr=env_bldr, q_args=adv_training_args.adv_net_args, device=device),
        env_bldr=env_bldr,
        args=adv_training_args,
        owner=owner,
        device=device,
    )
def __init__(self, env_bldr, baseline_args):
    super().__init__(
        net=DuelingQNet(env_bldr=env_bldr, q_args=baseline_args.q_net_args,
                        device=baseline_args.device_training),
        owner=None,
        env_bldr=env_bldr,
        args=baseline_args,
        device=baseline_args.device_training,
    )

    self._batch_arranged = torch.arange(self._args.batch_size, dtype=torch.long, device=self.device)
    self._minus_e20 = torch.full((self._args.batch_size, self._env_bldr.N_ACTIONS,),
                                 fill_value=-10e20,
                                 device=self.device,
                                 dtype=torch.float32,
                                 requires_grad=False)
class DDQN(_NetWrapperBase):

    def __init__(self, env_bldr, ddqn_args, owner):
        super().__init__(
            net=DuelingQNet(env_bldr=env_bldr, q_args=ddqn_args.q_args, device=ddqn_args.device_training),
            env_bldr=env_bldr,
            args=ddqn_args,
            owner=owner,
            device=ddqn_args.device_training,
        )

        self._eps = None
        self._target_net = DuelingQNet(env_bldr=env_bldr, q_args=ddqn_args.q_args,
                                       device=ddqn_args.device_training)
        self._target_net.eval()
        self.update_target_net()

        self._batch_arranged = torch.arange(ddqn_args.batch_size, dtype=torch.long, device=self.device)
        self._minus_e20 = torch.full((ddqn_args.batch_size, self._env_bldr.N_ACTIONS,),
                                     fill_value=-10e20,
                                     device=self.device,
                                     dtype=torch.float32,
                                     requires_grad=False)
        self._n_actions_arranged = np.arange(self._env_bldr.N_ACTIONS, dtype=np.int32).tolist()

    @property
    def eps(self):
        return self._eps

    @eps.setter
    def eps(self, value):
        self._eps = value

    def select_br_a(self, pub_obses, range_idxs, legal_actions_lists, explore=False):
        if explore and (np.random.random() < self._eps):
            return np.array(
                [legal_actions[np.random.randint(len(legal_actions))] for legal_actions in legal_actions_lists]
            )

        with torch.no_grad():
            self.eval()
            range_idxs = torch.tensor(range_idxs, dtype=torch.long, device=self.device)
            q = self._net(pub_obses=pub_obses, range_idxs=range_idxs,
                          legal_action_masks=rl_util.batch_get_legal_action_mask_torch(
                              n_actions=self._env_bldr.N_ACTIONS,
                              legal_actions_lists=legal_actions_lists,
                              device=self.device,
                              dtype=torch.float32)
                          ).cpu().numpy()

            for b in range(q.shape[0]):
                illegal_actions = [i for i in self._n_actions_arranged if i not in legal_actions_lists[b]]
                if len(illegal_actions) > 0:
                    illegal_actions = np.array(illegal_actions)
                    q[b, illegal_actions] = -1e20

            return np.argmax(q, axis=1)

    def update_target_net(self):
        self._target_net.load_state_dict(self._net.state_dict())
        self._target_net.eval()

    def _mini_batch_loop(self, buffer, grad_mngr):
        batch_pub_obs_t, \
            batch_a_t, \
            batch_range_idx, \
            batch_legal_action_mask_t, \
            batch_r_t, \
            batch_pub_obs_tp1, \
            batch_legal_action_mask_tp1, \
            batch_done = \
            buffer.sample(device=self.device, batch_size=self._args.batch_size)

        # [batch_size, n_actions]
        q1_t = self._net(pub_obses=batch_pub_obs_t,
                         range_idxs=batch_range_idx,
                         legal_action_masks=batch_legal_action_mask_t.to(torch.float32))
        q1_tp1 = self._net(pub_obses=batch_pub_obs_tp1,
                           range_idxs=batch_range_idx,
                           legal_action_masks=batch_legal_action_mask_tp1.to(torch.float32)).detach()
        q2_tp1 = self._target_net(pub_obses=batch_pub_obs_tp1,
                                  range_idxs=batch_range_idx,
                                  legal_action_masks=batch_legal_action_mask_tp1.to(torch.float32)).detach()

        # ______________________________________________ TD Learning _______________________________________________
        # [batch_size]
        q1_t_of_a_selected = q1_t[self._batch_arranged, batch_a_t]

        # only consider allowed actions for tp1
        q1_tp1 = torch.where(batch_legal_action_mask_tp1, q1_tp1, self._minus_e20)

        # [batch_size]
        _, best_a_tp1 = q1_tp1.max(dim=-1, keepdim=False)
        q2_best_a_tp1 = q2_tp1[self._batch_arranged, best_a_tp1]
        q2_best_a_tp1 = q2_best_a_tp1 * (1.0 - batch_done)

        target = batch_r_t + q2_best_a_tp1
        grad_mngr.backprop(pred=q1_t_of_a_selected, target=target)

    def state_dict(self):
        return {
            "q_net": self._net.state_dict(),
            "target_net": self._target_net.state_dict(),
            "eps": self._eps,
            "owner": self.owner,
            "args": self._args,
        }

    def load_state_dict(self, state):
        assert self.owner == state["owner"]
        # Not loading args by design
        self._net.load_state_dict(state["q_net"])
        self._target_net.load_state_dict(state["target_net"])
        self._eps = state["eps"]

    @staticmethod
    def from_state_dict(state_dict, env_bldr):
        ddqn = DDQN(owner=state_dict["owner"], ddqn_args=state_dict["args"], env_bldr=env_bldr)
        ddqn.load_state_dict(state_dict)
        ddqn.update_target_net()
        return ddqn

    @staticmethod
    def inference_version_from_state_dict(state_dict, env_bldr):
        ddqn = DDQN.from_state_dict(state_dict=state_dict, env_bldr=env_bldr)
        ddqn.buf = None
        ddqn.eps = None
        return ddqn
def _get_new_net(self):
    return DuelingQNet(q_args=self._args.ddqn_args.q_args, env_bldr=self._eval_env_bldr, device=self._device)
class IterationStrategy:

    def __init__(self, t_prof, owner, env_bldr, device, cfr_iter):
        self._t_prof = t_prof
        self._owner = owner
        self._env_bldr = env_bldr
        self._device = device
        self._cfr_iter = cfr_iter

        self._adv_net = None
        self._all_range_idxs = torch.arange(self._env_bldr.rules.RANGE_SIZE, device=self._device, dtype=torch.long)

    @property
    def owner(self):
        return self._owner

    @property
    def cfr_iteration(self):
        return self._cfr_iter

    @property
    def device(self):
        return self._device

    def reset(self):
        self._adv_net = None

    def get_action(self, pub_obses, range_idxs, legal_actions_lists):
        a_probs = self.get_a_probs(pub_obses=pub_obses, range_idxs=range_idxs,
                                   legal_actions_lists=legal_actions_lists)
        return torch.multinomial(torch.from_numpy(a_probs), num_samples=1).cpu().numpy()

    def get_a_probs2(self, pub_obses, range_idxs, legal_action_masks, to_np=True):
        """
        Args:
            pub_obses (list):                   batch (list) of np arrays of shape [np.arr([history_len, n_features]), ...)
            range_idxs (list):                  batch (list) of range_idxs (one for each pub_obs) [2, 421, 58, 912, ...]
            legal_action_masks (torch.Tensor)
        """
        with torch.no_grad():
            bs = len(range_idxs)

            if self._adv_net is None:  # at iteration 0
                uniform_even_legal = legal_action_masks / (legal_action_masks.sum(-1)
                                                           .unsqueeze(-1)
                                                           .expand_as(legal_action_masks))
                if to_np:
                    return uniform_even_legal.cpu().numpy()
                return uniform_even_legal

            else:
                range_idxs = torch.tensor(range_idxs, dtype=torch.long, device=self._device)
                pub_obses = torch.tensor(pub_obses, dtype=torch.float32, device=self._device)

                advantages = self._adv_net(pub_obses=pub_obses,
                                           range_idxs=range_idxs,
                                           legal_action_masks=legal_action_masks)

                # """"""""""""""""""""
                relu_advantages = F.relu(advantages, inplace=False)  # Because the sum of *positive* regret matters in CFR
                sum_pos_adv_expanded = relu_advantages.sum(1).unsqueeze(-1).expand_as(relu_advantages)

                # """"""""""""""""""""
                # In case all advantages are negative
                # """"""""""""""""""""
                best_legal_deterministic = torch.zeros((bs, self._env_bldr.N_ACTIONS,),
                                                       dtype=torch.float32,
                                                       device=self._device)
                bests = torch.argmax(
                    torch.where(legal_action_masks.byte(), advantages,
                                torch.full_like(advantages, fill_value=-10e20)),
                    dim=1
                )
                _batch_arranged = torch.arange(bs, device=self._device, dtype=torch.long)
                best_legal_deterministic[_batch_arranged, bests] = 1

                # """"""""""""""""""""
                # Strategy
                # """"""""""""""""""""
                strategy = torch.where(
                    sum_pos_adv_expanded > 0,
                    relu_advantages / sum_pos_adv_expanded,
                    best_legal_deterministic,
                )

                if to_np:
                    strategy = strategy.cpu().numpy()
                return strategy

    def get_a_probs(self, pub_obses, range_idxs, legal_actions_lists, to_np=True):
        """
        Args:
            pub_obses (list):                   batch (list) of np arrays of shape [np.arr([history_len, n_features]), ...)
            range_idxs (list):                  batch (list) of range_idxs (one for each pub_obs) [2, 421, 58, 912, ...]
            legal_actions_lists (list):         batch (list) of lists of integers that represent legal actions
        """
        with torch.no_grad():
            masks = rl_util.batch_get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                              legal_actions_lists=legal_actions_lists,
                                                              device=self._device,
                                                              dtype=torch.float32)
            return self.get_a_probs2(pub_obses=pub_obses,
                                     range_idxs=range_idxs,
                                     legal_action_masks=masks,
                                     to_np=to_np)

    def get_a_probs_for_each_hand(self, pub_obs, legal_actions_list):
        """
        Args:
            pub_obs (np.array(shape=(seq_len, n_features,)))
            legal_actions_list (list):      list of ints representing legal actions
        """
        if self._t_prof.DEBUGGING:
            assert isinstance(pub_obs, np.ndarray)
            assert len(pub_obs.shape) == 2, "all hands have the same public obs"
            assert isinstance(legal_actions_list[0], int), \
                "all hands do the same actions. no need to batch, just pass an int"

        return self._get_a_probs_of_hands(pub_obs=pub_obs, legal_actions_list=legal_actions_list,
                                          range_idxs_tensor=self._all_range_idxs)

    def get_a_probs_for_each_hand_in_list(self, pub_obs, range_idxs, legal_actions_list):
        """
        Args:
            pub_obs (np.array(shape=(seq_len, n_features,)))
            range_idxs (np.ndarray):        list of range_idxs to evaluate in public state "pub_obs"
            legal_actions_list (list):      list of ints representing legal actions
        """
        if self._t_prof.DEBUGGING:
            assert isinstance(pub_obs, np.ndarray)
            assert isinstance(range_idxs, np.ndarray)
            assert len(pub_obs.shape) == 2, "all hands have the same public obs"
            assert isinstance(legal_actions_list[0], int), "all hands can do the same actions. no need to batch"

        return self._get_a_probs_of_hands(pub_obs=pub_obs, legal_actions_list=legal_actions_list,
                                          range_idxs_tensor=torch.from_numpy(range_idxs).to(dtype=torch.long,
                                                                                            device=self._device))

    def _get_a_probs_of_hands(self, pub_obs, range_idxs_tensor, legal_actions_list):
        with torch.no_grad():
            n_hands = range_idxs_tensor.size(0)

            if self._adv_net is None:  # at iteration 0
                uniform_even_legal = torch.zeros((self._env_bldr.N_ACTIONS,), dtype=torch.float32,
                                                 device=self._device)
                uniform_even_legal[legal_actions_list] = 1.0 / len(legal_actions_list)  # always >0
                uniform_even_legal = uniform_even_legal.unsqueeze(0).expand(n_hands, self._env_bldr.N_ACTIONS)
                return uniform_even_legal.cpu().numpy()

            else:
                legal_action_masks = rl_util.get_legal_action_mask_torch(n_actions=self._env_bldr.N_ACTIONS,
                                                                         legal_actions_list=legal_actions_list,
                                                                         device=self._device,
                                                                         dtype=torch.float32)
                legal_action_masks = legal_action_masks.unsqueeze(0).expand(n_hands, -1)

                advantages = self._adv_net(pub_obses=[pub_obs] * n_hands,
                                           range_idxs=range_idxs_tensor,
                                           legal_action_masks=legal_action_masks)

                # """"""""""""""""""""
                relu_advantages = F.relu(advantages, inplace=False)  # Because the sum of *positive* regret matters in CFR
                sum_pos_adv_expanded = relu_advantages.sum(1).unsqueeze(-1).expand_as(relu_advantages)

                # """"""""""""""""""""
                # In case all advantages are negative
                # """"""""""""""""""""
                best_legal_deterministic = torch.zeros((n_hands, self._env_bldr.N_ACTIONS,), dtype=torch.float32,
                                                       device=self._device)
                bests = torch.argmax(
                    torch.where(legal_action_masks.byte(), advantages,
                                torch.full_like(advantages, fill_value=-10e20)),
                    dim=1
                )
                _batch_arranged = torch.arange(n_hands, device=self._device, dtype=torch.long)
                best_legal_deterministic[_batch_arranged, bests] = 1

                # """"""""""""""""""""
                # Strategy
                # """"""""""""""""""""
                strategy = torch.where(
                    sum_pos_adv_expanded > 0,
                    relu_advantages / sum_pos_adv_expanded,
                    best_legal_deterministic,
                )

                return strategy.cpu().numpy()

    def state_dict(self):
        return {
            "owner": self._owner,
            "net": self.net_state_dict(),
            "iter": self._cfr_iter,
        }

    @staticmethod
    def build_from_state_dict(t_prof, env_bldr, device, state):
        s = IterationStrategy(t_prof=t_prof, env_bldr=env_bldr, device=device,
                              owner=state["owner"], cfr_iter=state["iter"])
        s.load_state_dict(state=state)  # loads net state
        return s

    def load_state_dict(self, state):
        assert self._owner == state["owner"]
        self.load_net_state_dict(state["net"])
        self._cfr_iter = state["iter"]

    def net_state_dict(self):
        """ Wraps net.state_dict() with the option of returning None if the net is None. """
        if self._adv_net is None:
            return None
        return self._adv_net.state_dict()

    def load_net_state_dict(self, state_dict):
        if state_dict is None:
            # Should only happen at iteration 0; in that case this class returns random actions.
            return
        else:
            self._adv_net = DuelingQNet(q_args=self._t_prof.module_args["adv_training"].adv_net_args,
                                        env_bldr=self._env_bldr,
                                        device=self._device)
            if self._t_prof.module_args['adv_training'].init_adv_model == "last":
                # load the net's inner parameters only if we actually use them
                self._adv_net.load_state_dict(state_dict)
            self._adv_net.to(self._device)

            self._adv_net.eval()
            for param in self._adv_net.parameters():
                param.requires_grad = False

    def get_copy(self, device=None):
        _device = self._device if device is None else device
        return IterationStrategy.build_from_state_dict(t_prof=self._t_prof, env_bldr=self._env_bldr,
                                                       device=_device, state=self.state_dict())
def _get_new_baseline_net(self):
    return DuelingQNet(q_args=self._baseline_args.q_net_args, env_bldr=self._env_bldr, device=self._device)
def _get_new_adv_net(self):
    return DuelingQNet(q_args=self._adv_args.adv_net_args, env_bldr=self._env_bldr, device=self._device)
class ParameterServer(_ParameterServerBase):

    def __init__(self, t_prof, seat_id, chief_handle):
        self.ddqn_args = t_prof.module_args["ddqn"]
        self.avg_args = t_prof.module_args["avg"]
        super().__init__(t_prof=t_prof, chief_handle=chief_handle)

        self.seat_id = seat_id
        self.global_iter_id = 0
        self.eps = self.ddqn_args.eps_start

        self.q_net = DuelingQNet(q_args=self.ddqn_args.q_args, env_bldr=self._env_bldr, device=self._device)
        self.avg_net = AvrgStrategyNet(avrg_net_args=self.avg_args.avg_net_args, env_bldr=self._env_bldr,
                                       device=self._device)

        self.br_optim = rl_util.str_to_optim_cls(self.ddqn_args.optim_str)(self.q_net.parameters(),
                                                                           lr=self.ddqn_args.lr)
        self.avg_optim = rl_util.str_to_optim_cls(self.avg_args.optim_str)(self.avg_net.parameters(),
                                                                           lr=self.avg_args.lr)

        self.eps_exp = self._ray.remote(self._chief_handle.create_experiment,
                                        t_prof.name + ": epsilon Plyr" + str(seat_id))
        self._log_eps()

    # ______________________________________________ API to pull from PS _______________________________________________
    def get_avg_weights(self):
        self.avg_net.zero_grad()
        return self._ray.state_dict_to_numpy(self.avg_net.state_dict())

    def get_q1_weights(self):
        self.q_net.zero_grad()
        return self._ray.state_dict_to_numpy(self.q_net.state_dict())

    def get_eps(self):
        return self.eps

    def _log_eps(self):
        self._ray.remote(self._chief_handle.add_scalar,
                         self.eps_exp, "Epsilon", self.global_iter_id, self.eps)

    # ____________________________________________ API to make PS compute ______________________________________________
    def apply_grads_br(self, list_grads):
        self._apply_grads(list_of_grads=list_grads, optimizer=self.br_optim, net=self.q_net,
                          grad_norm_clip=self.ddqn_args.grad_norm_clipping)

    def apply_grads_avg(self, list_grads):
        self._apply_grads(list_of_grads=list_grads, optimizer=self.avg_optim, net=self.avg_net,
                          grad_norm_clip=self.avg_args.grad_norm_clipping)

    def increment(self):
        self.global_iter_id += 1
        self.eps = rl_util.polynomial_decay(base=self.ddqn_args.eps_start,
                                            const=self.ddqn_args.eps_const,
                                            exponent=self.ddqn_args.eps_exponent,
                                            minimum=self.ddqn_args.eps_min,
                                            counter=self.global_iter_id)
        self._log_eps()
        return self.seat_id

    # ______________________________________________ API for checkpointing _____________________________________________
    def checkpoint(self, curr_step):
        state = {
            "seat_id": self.seat_id,
            "global_iter_id": self.global_iter_id,
            "eps": self.eps,
            "q_net": self.q_net.state_dict(),
            "avg_net": self.avg_net.state_dict(),
            "br_optim": self.br_optim.state_dict(),
            "avg_optim": self.avg_optim.state_dict(),
        }

        with open(self._get_checkpoint_file_path(name=self._t_prof.name, step=curr_step,
                                                 cls=self.__class__, worker_id="P" + str(self.seat_id)),
                  "wb") as pkl_file:
            pickle.dump(obj=state, file=pkl_file, protocol=pickle.HIGHEST_PROTOCOL)

    def load_checkpoint(self, name_to_load, step):
        with open(self._get_checkpoint_file_path(name=name_to_load, step=step,
                                                 cls=self.__class__, worker_id="P" + str(self.seat_id)),
                  "rb") as pkl_file:
            state = pickle.load(pkl_file)

            assert self.seat_id == state["seat_id"]

            self.eps = state["eps"]
            self.global_iter_id = state["global_iter_id"]
            self.q_net.load_state_dict(state["q_net"])
            self.avg_net.load_state_dict(state["avg_net"])
            self.br_optim.load_state_dict(state["br_optim"])
            self.avg_optim.load_state_dict(state["avg_optim"])
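# ___________________________________________________________________________________________________________________
# Illustrative sketch (an assumption, not the repo's implementation): `get_q1_weights` / `get_avg_weights` above
# return net weights via `self._ray.state_dict_to_numpy(...)`, presumably so workers can pull them as plain numpy
# arrays over Ray. A minimal plain-torch equivalent of that round trip could look like this; both helper names here
# are hypothetical.
import torch

def state_dict_to_numpy(state_dict):
    # Detach every tensor, move it to CPU, and convert it to a numpy array for cheap serialization.
    return {k: v.detach().cpu().numpy() for k, v in state_dict.items()}

def numpy_to_state_dict(np_dict, device="cpu"):
    # Inverse conversion a worker would apply before calling load_state_dict() on its local copy.
    return {k: torch.from_numpy(v).to(device) for k, v in np_dict.items()}
# ___________________________________________________________________________________________________________________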