def extract_state(self, state):
    """Encode a raw game state into the observation dict handed to agents.

    Args:
        state (dict): Raw state. The keys ``'hand'``, ``'target'`` and
            ``'others_hand'`` are read here.

    Returns:
        dict: ``'obs'`` is a ``(7, 4, 15)`` integer array and
        ``'legal_actions'`` is the result of ``self.get_legal_actions()``.
    """
    planes = np.zeros((7, 4, 15), dtype=int)

    # The encoders write into the views they are given, so each one fills
    # its own slice of ``planes`` in place:
    #   planes 0-2 -> own hand, plane 3 -> target, planes 4-6 -> others' hand
    own_hand_view = planes[:3]
    target_view = planes[3]
    others_hand_view = planes[4:]

    encode_hand(own_hand_view, state['hand'])
    encode_target(target_view, state['target'])
    encode_hand(others_hand_view, state['others_hand'])

    return {
        'obs': planes,
        'legal_actions': self.get_legal_actions(),
    }
def _extract_state(self, state):
    """Encode a raw game state and attach the raw data and action record.

    Args:
        state (dict): Raw state. The keys ``'hand'``, ``'target'`` and
            ``'legal_actions'`` are read here.

    Returns:
        dict: With the keys
            * ``'obs'`` - ``(4, 4, 15)`` integer array (planes 0-2 filled
              from the hand, plane 3 from the target),
            * ``'legal_actions'`` - result of ``self._get_legal_actions()``,
            * ``'raw_obs'`` - the ``state`` argument itself (not a copy),
            * ``'raw_legal_actions'`` - a new list built from
              ``state['legal_actions']``,
            * ``'action_record'`` - ``self.action_recorder``.
    """
    planes = np.zeros((4, 4, 15), dtype=int)

    # Both encoders fill the slice they receive in place.
    hand_view, target_view = planes[:3], planes[3]
    encode_hand(hand_view, state['hand'])
    encode_target(target_view, state['target'])

    return {
        'obs': planes,
        'legal_actions': self._get_legal_actions(),
        'raw_obs': state,
        'raw_legal_actions': list(state['legal_actions']),
        'action_record': self.action_recorder,
    }
def _extract_state(self, state):
    """Encode a raw game state, optionally attaching raw data and history.

    Args:
        state (dict): Raw state. The keys ``'hand'``, ``'target'`` and
            ``'others_hand'`` are always read; ``'legal_actions'`` is read
            only when ``self.allow_raw_data`` is truthy.

    Returns:
        dict: Always contains ``'obs'`` (a ``(7, 4, 15)`` integer array) and
        ``'legal_actions'`` (result of ``self._get_legal_actions()``).
        When ``self.allow_raw_data`` is truthy it also carries ``'raw_obs'``
        (the ``state`` object itself) and ``'raw_legal_actions'`` (a new
        list). When ``self.record_action`` is truthy it also carries
        ``'action_record'`` (``self.action_recorder``).
    """
    planes = np.zeros((7, 4, 15), dtype=int)

    # In-place encoding: planes 0-2 own hand, plane 3 target,
    # planes 4-6 the other players' hand.
    encode_hand(planes[:3], state['hand'])
    encode_target(planes[3], state['target'])
    encode_hand(planes[4:], state['others_hand'])

    result = {
        'obs': planes,
        'legal_actions': self._get_legal_actions(),
    }

    if self.allow_raw_data:
        result.update(
            raw_obs=state,
            raw_legal_actions=list(state['legal_actions']),
        )
    if self.record_action:
        result['action_record'] = self.action_recorder

    return result
def test_encode_hand(self):
    """Check ``encode_hand`` on a regular hand and on a hand of wild cards."""
    # Regular hand: for every one of the 15 card slots, the entries summed
    # over the 3 planes and the 4 colours must add up to 4.
    regular_hand = ['y-1', 'r-8', 'b-9', 'y-reverse', 'r-skip']
    regular_planes = np.zeros((3, 4, 15), dtype=int)
    encode_hand(regular_planes, regular_hand)
    per_slot_totals = regular_planes.sum(axis=(0, 1))
    for slot_total in per_slot_totals:
        self.assertEqual(slot_total, 4)

    # Wild hand: plane 1 must be set for every colour in the last two slots.
    wild_hand = ['r-wild', 'g-wild_draw_4']
    wild_planes = np.zeros((3, 4, 15), dtype=int)
    encode_hand(wild_planes, wild_hand)
    for colour_row in wild_planes[1]:
        for slot in (-2, -1):
            self.assertEqual(colour_row[slot], 1)
def _extract_state(self, state):
    """Encode a raw game state into a flat observation vector.

    The observation is the concatenation of, in order:

    1. hand planes 1 and 2 of a ``(3, 4, 15)`` hand encoding, flattened
       (plane 0 is computed but left out of the observation),
    2. a 4-slot bucket vector for the next player's card count
       (slots mean 1, 2, 3 and "4 or more" cards),
    3. a length-19 vector filled by ``encode_target``.

    Args:
        state (dict): Raw state. The keys ``'hand'``, ``'target'`` and
            ``'next_player_amount'`` are always read; ``'legal_actions'``
            is read only when ``self.allow_raw_data`` is truthy.

    Returns:
        dict: Always contains ``'obs'`` (the flat vector described above)
        and ``'legal_actions'`` (result of ``self._get_legal_actions()``).
        When ``self.allow_raw_data`` is truthy it also carries ``'raw_obs'``
        (the ``state`` object itself) and ``'raw_legal_actions'`` (a new
        list). When ``self.record_action`` is truthy it also carries
        ``'action_record'`` (``self.action_recorder``).
    """
    hand_planes = np.zeros((3, 4, 15), dtype=int)
    target_encoding = np.zeros([19])

    # Both encoders write into the arrays they are given.
    encode_hand(hand_planes[:3], state['hand'])
    encode_target(target_encoding, state['target'])

    # Only planes 1 and 2 are exposed to the agent.
    hand_features = hand_planes[1:].flatten()

    # Bucket the next player's card count into [1, 2, 3, 4+].
    next_player_amount = [0, 0, 0, 0]
    cards_left = state['next_player_amount']
    if cards_left >= 4:
        next_player_amount[3] = 1
    elif cards_left >= 1:
        next_player_amount[cards_left - 1] = 1
    # A count below 1 (e.g. an empty hand) leaves every bucket at 0.
    # Indexing with ``cards_left - 1`` would wrap around to the last slot
    # and report such a hand as "4 or more" cards.

    obs = np.concatenate((hand_features, next_player_amount, target_encoding))

    extracted_state = {
        'obs': obs,
        'legal_actions': self._get_legal_actions(),
    }
    if self.allow_raw_data:
        extracted_state['raw_obs'] = state
        extracted_state['raw_legal_actions'] = list(state['legal_actions'])
    if self.record_action:
        extracted_state['action_record'] = self.action_recorder
    return extracted_state