def batch_learning_update(actor, critic, target_actor, target_critic, params):
    """Runs minibatch actor-critic updates over all trajectories stored for the current training round."""
    mongo = MongoDB()
    actor.train()
    query = {'training_round': params['training_round']}
    projection = {
        'obs': 1,
        'state': 1,
        'betsize_mask': 1,
        'action_mask': 1,
        'action': 1,
        'reward': 1,
        '_id': 0
    }
    db_data = mongo.get_data(query, projection)
    trainloader = return_trajectoryloader(db_data)
    for _ in range(params['learning_rounds']):
        losses = []
        for i, data in enumerate(trainloader):
            critic_loss = update_actor_critic_batch(data, actor, critic, target_actor, target_critic, params)
            losses.append(critic_loss)
            # print(f'Learning Round {i}, critic loss {sum(losses)}')
    mongo.close()
    return actor, critic, params
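# return_trajectoryloader is defined elsewhere in the repo; a minimal sketch, assuming it
# wraps the MongoDB documents in a PyTorch Dataset/DataLoader so the batch update above can
# iterate over tensor minibatches. The field conversions, helper names, and batch size here
# are illustrative assumptions, not the repo's actual implementation.
import torch
from torch.utils.data import Dataset, DataLoader

class _ExampleTrajectoryDataset(Dataset):
    """Turns a list of MongoDB trajectory documents into indexable tensor samples."""
    def __init__(self, db_data):
        self.data = list(db_data)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        doc = self.data[idx]
        return {
            'obs': torch.tensor(doc['obs'], dtype=torch.float32),
            'state': torch.tensor(doc['state'], dtype=torch.float32),
            'betsize_mask': torch.tensor(doc['betsize_mask'], dtype=torch.float32),
            'action_mask': torch.tensor(doc['action_mask'], dtype=torch.float32),
            'action': torch.tensor(doc['action'], dtype=torch.long),
            'reward': torch.tensor(doc['reward'], dtype=torch.float32),
        }

def _example_trajectoryloader(db_data, batch_size=128):
    # Shuffling decorrelates consecutive poker rounds within a learning round.
    return DataLoader(_ExampleTrajectoryDataset(db_data), batch_size=batch_size, shuffle=True)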
def dual_learning_update(actor, critic, target_actor, target_critic, params, rank):
    """Per-worker learning update: trains on the trajectories logged by the worker with this
    rank, soft-updating the target networks after each round."""
    mongo = MongoDB()
    actor.train()
    query = {'training_round': params['training_round'], 'rank': rank}
    projection = {
        'obs': 1,
        'state': 1,
        'betsize_mask': 1,
        'action_mask': 1,
        'action': 1,
        'reward': 1,
        '_id': 0
    }
    data = mongo.get_data(query, projection)
    for _ in range(params['learning_rounds']):
        for poker_round in data:
            update_actor_critic(poker_round, critic, target_critic, actor, target_actor, params)
            # Polyak-average the online networks into their targets after each update step.
            soft_update(critic, target_critic, params['device'])
            soft_update(actor, target_actor, params['device'])
    mongo.close()
    del data
    return actor, critic, params
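# soft_update is defined elsewhere in the repo and is called above with params['device'] as
# its third argument, so its exact signature differs from this sketch. The sketch below only
# illustrates the standard Polyak-averaging target-network update it presumably performs;
# the helper name and tau value are assumptions for illustration.
def _example_soft_update(local_net, target_net, tau=1e-3):
    """Blend online weights into the target network: theta_target <- tau*theta_local + (1-tau)*theta_target."""
    for target_param, local_param in zip(target_net.parameters(), local_net.parameters()):
        target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)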
def plot_critic_values(training_round=0):
    """Plots a rolling mean of the critic error (predicted value of the taken action minus the observed reward)."""
    query = {
        # 'position': args.position,
        # 'training_round': args.run
    }
    projection = {'values': 1, 'reward': 1, 'action': 1, '_id': 0}
    mongo = MongoDB()
    # for position in [pdt.PositionStrs.SB, pdt.PositionStrs.BB]:
    #     query['position'] = position
    data = mongo.get_data(query, projection)
    rewards = []
    actions = []
    values = []
    for point in data:
        rewards.append(point['reward'])
        values.append(point['values'])
        actions.append(point['action'])
    M = len(values)
    # Plot value loss over time, smoothed over a quarter of the data.
    interval = M // 4
    values = np.vstack(values)
    rewards = np.vstack(rewards)
    actions = np.array(actions)
    # Select the critic value of the action that was actually taken in each round.
    mask = np.zeros((actions.size, pdt.Action.RAISE), dtype=bool)
    mask[np.arange(actions.size), actions] = 1
    critic_loss = values[mask].reshape(M, 1) - rewards
    critic_loss_rolling_mean = []
    for i in range(len(critic_loss) - interval):
        critic_loss_rolling_mean.append(np.mean(critic_loss[i:interval + i]))
    plot_data('Critic loss', [critic_loss_rolling_mean], ['Values'])
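# plot_data comes from the repo's plotting utilities; a minimal sketch of the kind of line
# plot it presumably produces. The figure size, axis labels, output directory, and file
# naming are illustrative assumptions, not the repo's actual implementation.
import matplotlib.pyplot as plt

def _example_plot_data(title, data_series, labels, path='images/'):
    fig, ax = plt.subplots(figsize=(8, 5))
    for series, label in zip(data_series, labels):
        ax.plot(series, label=label)
    ax.set_title(title)
    ax.set_xlabel('Training step')
    ax.legend()
    fig.savefig(f"{path}{title.replace(' ', '_')}.png", bbox_inches='tight')
    plt.close(fig)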
def plot_betsize_probabilities(training_round=0):
    """Plots betsize frequencies by hand for each position in the given training round."""
    query = {'training_round': training_round}
    projection = {'betsizes': 1, 'hand': 1, '_id': 0}
    params = {'interval': 100}
    mongo = MongoDB()
    gametype = "Omaha"
    for position in [pdt.PositionStrs.SB, pdt.PositionStrs.BB]:
        query['position'] = position
        data = mongo.get_data(query, projection)
        betsize, unique_hands, unique_betsize = mongo.betsizeByHand(data, params)
        hand_labels = [
            f'Hand {pdt.Globals.KUHN_CARD_DICT[hand]}' for hand in unique_hands
        ]
        action_labels = [size for size in unique_betsize]
        plot_frequencies(
            f'{gametype}_betsize_probabilities_for_{query["position"]}',
            betsize, hand_labels, action_labels)
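# plot_frequencies is also defined in the repo's plotting utilities; a sketch of a plausible
# implementation, assuming the frequency data is indexable as [hand][action][step] and that
# one subplot is drawn per hand. The layout and output path are assumptions for illustration.
import matplotlib.pyplot as plt

def _example_plot_frequencies(title, frequencies, hand_labels, action_labels, path='images/'):
    fig, axes = plt.subplots(len(hand_labels), 1, sharex=True,
                             figsize=(8, 3 * len(hand_labels)), squeeze=False)
    for ax, hand_freqs, hand_label in zip(axes[:, 0], frequencies, hand_labels):
        # One frequency curve per action (or betsize) for this hand.
        for action_freqs, action_label in zip(hand_freqs, action_labels):
            ax.plot(action_freqs, label=action_label)
        ax.set_ylabel(hand_label)
        ax.legend(loc='upper right')
    axes[-1, 0].set_xlabel('Hands (rolling interval)')
    fig.suptitle(title)
    fig.savefig(f"{path}{title}.png", bbox_inches='tight')
    plt.close(fig)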
def plot_action_frequencies(actiontype, handtype, training_round=0):
    """Plots action frequencies by hand or by hand strength for each position in the given training round."""
    print(actiontype, handtype)
    query = {'training_round': training_round}
    projection = {'action': 1, 'hand_strength': 1, 'hand': 1, '_id': 0}
    data_params = {'interval': 100}
    mongo = MongoDB()
    # gametype = mongo.get_gametype(training_round)
    gametype = "Omaha"
    for position in [pdt.PositionStrs.SB, pdt.PositionStrs.BB]:
        query['position'] = position
        data = mongo.get_data(query, projection)
        if handtype == pdt.VisualHandTypes.HAND:
            actions, hands, unique_actions = mongo.actionByHand(data, data_params)
        else:
            actions, hands, unique_actions = mongo.actionByHandStrength(data, data_params)
        hand_labels = HAND_LABELS_DICT[actiontype](hands)
        action_labels = [pdt.ACTION_DICT[act] for act in unique_actions]
        plot_frequencies(
            f'{gametype}_action_{handtype}_for_{query["position"]}',
            actions, hand_labels, action_labels)
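# One way these plotting entry points might be wired up from the command line; the flag
# names, defaults, and the 'action' key for --actiontype are illustrative assumptions
# (valid actiontype values depend on the repo's HAND_LABELS_DICT), not the repo's actual CLI.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Plot training diagnostics stored in MongoDB')
    parser.add_argument('--run', type=int, default=0, help='training round to query')
    parser.add_argument('--actiontype', default='action',
                        help='key into HAND_LABELS_DICT (repo-specific)')
    args = parser.parse_args()

    plot_critic_values(training_round=args.run)
    plot_betsize_probabilities(training_round=args.run)
    plot_action_frequencies(args.actiontype, pdt.VisualHandTypes.HAND, training_round=args.run)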