def test_accuracy_encodings():
    """Test that the encoding is correct for dice targeted by an accuracy."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"], upgrades=[],
                         player_number=1)
    three_brace = ship.Ship(name="Triple Brace", template=ship_templates["Triple Brace"],
                            upgrades=[], player_number=2)

    # Make a brace token red
    three_brace.spend_token('brace', ArmadaTypes.green)

    enc_three_brace, world_state = make_encoding(attacker, three_brace, "short", agent)

    # Define the offsets for convenience
    token_begin = Encodings.getAttackTokenOffset()
    token_end = token_begin + ArmadaTypes.max_defense_tokens

    # Verify that no tokens are targeted at first
    assert 0.0 == enc_three_brace[token_begin:token_end].sum()

    # Now make a second brace token red and target the red token with an accuracy
    three_brace.spend_token('brace', ArmadaTypes.green)
    green_acc_begin = Encodings.getAttackTokenOffset()
    green_acc_end = green_acc_begin + len(ArmadaTypes.defense_tokens)
    red_acc_begin = Encodings.getAttackTokenOffset() + len(ArmadaTypes.defense_tokens)
    red_acc_end = red_acc_begin + len(ArmadaTypes.defense_tokens)
    world_state.attack.accuracy_defender_token(
        ArmadaTypes.defense_tokens.index('brace'), ArmadaTypes.red)
    encoding = Encodings.encodeAttackState(world_state)

    # Verify that only the red token has the accuracy flag set
    assert encoding[red_acc_begin + ArmadaTypes.defense_tokens.index('brace')].item() == 1.
    assert encoding[red_acc_begin:red_acc_end].sum().item() == 1.
    assert encoding[green_acc_begin:green_acc_end].sum().item() == 0.

    # Target both remaining green tokens
    world_state.attack.accuracy_defender_token(
        ArmadaTypes.defense_tokens.index('brace'), ArmadaTypes.green)
    world_state.attack.accuracy_defender_token(
        ArmadaTypes.defense_tokens.index('brace'), ArmadaTypes.green)
    encoding = Encodings.encodeAttackState(world_state)

    # Verify that two green braces and one red brace have the accuracy flag set
    assert encoding[red_acc_begin + ArmadaTypes.defense_tokens.index('brace')].item() == 1.
    assert encoding[red_acc_begin:red_acc_end].sum().item() == 1.
    assert encoding[green_acc_begin + ArmadaTypes.defense_tokens.index('brace')].item() == 2.
    assert encoding[green_acc_begin:green_acc_end].sum().item() == 2.

def test_spent_encodings():
    """Test that the encoding correctly marks spent defense tokens."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"], upgrades=[],
                         player_number=1)
    defender = ship.Ship(name="Defender", template=ship_templates["All Defense Tokens"],
                         upgrades=[], player_number=2)

    encoding, world_state = make_encoding(attacker, defender, "short", agent)

    # The defender and attacker come first, then the accuracy-targeted tokens, then the spent
    # tokens.
    spent_begin = 2 * ship.Ship.encodeSize() + 2 * len(ArmadaTypes.defense_tokens)
    spent_end = spent_begin + len(ArmadaTypes.defense_tokens)

    # Verify that no tokens are marked spent by default
    assert torch.sum(encoding[spent_begin:spent_end]) == 0.

    # Spend all of the tokens
    for tidx, ttype in enumerate(ArmadaTypes.defense_tokens):
        world_state.attack.defender_spend_token(ttype, 'green')

    encoding = Encodings.encodeAttackState(world_state)
    assert torch.sum(encoding[spent_begin:spent_end]).item() == len(ArmadaTypes.defense_tokens)

    # Try spending the tokens at different indices
    for tidx, ttype in enumerate(ArmadaTypes.defense_tokens):
        # Re-encode with fresh ships and then set only this token to spent.
        attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"], upgrades=[],
                             player_number=1)
        defender = ship.Ship(name="Defender", template=ship_templates["All Defense Tokens"],
                             upgrades=[], player_number=2)
        encoding, world_state = make_encoding(attacker, defender, "short", agent)
        world_state.attack.defender_spend_token(ttype, 'green')
        encoding = Encodings.encodeAttackState(world_state)
        assert torch.sum(encoding[spent_begin:spent_end]).item() == 1.0
        assert encoding[spent_begin:spent_end][tidx].item() == 1.0

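# For reference, the offset arithmetic in the two tests above implies the following leading
# layout of the attack-state encoding. This is a sketch inferred only from the offsets used in
# these tests; the Encodings class remains the authoritative definition. With
# n_tokens = len(ArmadaTypes.defense_tokens) and token_off = Encodings.getAttackTokenOffset():
#
#   [0, 2 * Ship.encodeSize())                            attacker and defender ship encodings
#   [token_off, token_off + n_tokens)                     green tokens targeted by an accuracy
#   [token_off + n_tokens, token_off + 2 * n_tokens)      red tokens targeted by an accuracy
#   [token_off + 2 * n_tokens, token_off + 3 * n_tokens)  spent-token flags
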
def make_encoding(ship_a, ship_b, attack_range, agent):
    """Build a world state and attack-state encoding for ship_a attacking ship_b front-to-front.

    Args:
        ship_a (Ship): Attacker.
        ship_b (Ship): Defender.
        attack_range (str): Attack range (e.g. "short").
        agent (Agent): Agent for the simulation (currently unused here).
    Returns:
        (torch.Tensor, WorldState): The attack-state encoding and the constructed world state.
    """
    world_state = WorldState()
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)

    pool_colors, pool_faces = ship_a.roll("front", attack_range)
    attack = AttackState(attack_range=attack_range, attacker=ship_a, attacking_hull="front",
                         defender=ship_b, defending_hull="front",
                         pool_colors=pool_colors, pool_faces=pool_faces)
    world_state.updateAttack(attack)

    encoding = Encodings.encodeAttackState(world_state)
    return encoding, world_state

def test_die_encodings():
    """Test that dice are encoded correctly in the attack state."""
    agent = RandomAgent()
    attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"], upgrades=[],
                         player_number=1)

    # Count the front armament dice in the ship template.
    template_front_dice = 0
    for color in ['Red', 'Blue', 'Black']:
        if 0 < len(ship_templates["Attacker"][f"Armament Front {color}"]):
            template_front_dice += int(ship_templates["Attacker"][f"Armament Front {color}"])

    dice_begin = ArmadaTypes.hull_zones.index('front') * len(ArmadaDice.die_colors)
    dice_end = dice_begin + len(ArmadaDice.die_colors)
    front_dice = int(attacker.get_range('dice')[dice_begin:dice_end].sum())
    # The ship encoding should have the same dice as the template
    assert front_dice == template_front_dice

    defender = ship.Ship(name="Defender", template=ship_templates["All Defense Tokens"],
                         upgrades=[], player_number=2)

    encoding, world_state = make_encoding(attacker, defender, "short", agent)

    attack_state_encoding = Encodings.encodeAttackState(world_state)
    die_offset = Encodings.getAttackDiceOffset()
    # The attack state should have a roll with as many dice as the ship has.
    dice_encoding = attack_state_encoding[die_offset:die_offset + Encodings.dieEncodingSize()]
    assert int(dice_encoding.sum().item()) == front_dice

def test_roll_encodings():
    """Test that the encoding is correct for dice pools and faces."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"], upgrades=[],
                         player_number=1)
    no_token = ship.Ship(name="No Defense Tokens", template=ship_templates["No Defense Tokens"],
                         upgrades=[], player_number=2)

    dice_begin = Encodings.getAttackDiceOffset()

    # Do 100 trials to ensure everything is working as expected
    _, world_state = make_encoding(attacker, no_token, "short", agent)
    for _ in range(100):
        # Make a random roll and encode the attack state. Each die slot is indexed by color
        # (3 colors) and face (6 faces).
        pool_colors, pool_faces = attacker.roll("front", "short")
        attack = world_state.attack
        attack.pool_faces = pool_faces
        attack.pool_colors = pool_colors
        world_state.updateAttack(attack)
        enc_attack = Encodings.encodeAttackState(world_state)

        # Try to find a match for each (color, face) pair in the encoding
        enc_dice = enc_attack[dice_begin:]
        for face, color in zip(attack.pool_faces, attack.pool_colors):
            assert 0. < enc_dice[Encodings.dieOffset(color=color, face=face)].item()
            enc_dice[Encodings.dieOffset(color=color, face=face)] -= 1

        # All dice from the pool should have been matched and there should be no more encoded
        assert enc_dice.sum().item() <= 0.

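# Note on the test above: the layout comment suggests the dice pool region stores one count per
# (color, face) pair. The mapping below is a hypothetical illustration of what
# Encodings.dieOffset might compute (ArmadaDice.die_faces is an assumed attribute name), not
# the actual implementation:
def sketch_die_offset(color, face):
    """Hypothetical (color, face) -> slot index with 3 die colors and 6 die faces."""
    return (ArmadaDice.die_colors.index(color) * len(ArmadaDice.die_faces)
            + ArmadaDice.die_faces.index(face))
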
def test_defense_tokens_model(spend_defense_tokens_model):
    """Test basic network learning loop.

    Test lifetime predictions during the "spend defense tokens" phase.
    """
    network, errors, eval_errors = spend_defense_tokens_model
    network.eval()

    phase_name = "attack - spend defense tokens"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 32
    batch = torch.Tensor(batch_size, input_size).to(target_device)

    # First verify that evaluation errors decreased during training.
    print("Eval errors for B were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    # Let's examine predictions for different dice pools and spent defense tokens.
    # Go through the following scenarios:
    # 1.1 An attack with more than enough damage to destroy the ship
    # 1.2 The same attack but a brace that would prevent destruction
    # 1.3 The same attack but a redirect that would prevent destruction
    # Result: 1.1 should have a lower lifetime than 1.2 and 1.3
    # 2.1 An attack that can barely destroy the ship
    # 2.2 An attack that barely will not destroy the ship
    # Result: 2.1 should have a lower lifetime than 2.2.
    # Ideally 1.1 and 2.1 would predict the current round.

    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 4
    world_state.setPhase("ship phase", phase_name)
    # Set the front hull zone to 2 shields
    ship_b.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
    # Set the hull to 3 (by assigning damage to reduce the remaining hull to 3)
    ship_b.set('damage', ship_b.get('hull') - 3)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)

    # Scenario 1.1: spend no defense token.
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[0] = torch.cat((action_encoding.to(target_device), state_encoding.to(target_device)))

    # Scenario 1.2: spend a green brace.
    action = [("brace", (ArmadaTypes.green, None))]
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, action)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[1] = torch.cat((action_encoding.to(target_device), state_encoding.to(target_device)))

    # Scenario 1.3: spend a green redirect that would prevent destruction.
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 2 + ['hit'] * 2
    world_state.setPhase("ship phase", phase_name)
    # Set the front hull zone to 2 shields
    ship_b.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
    # Set the hull to 3 (by assigning damage to reduce the remaining hull to 3)
    ship_b.set('damage', ship_b.get('hull') - 3)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    action = [("redirect", (ArmadaTypes.green, [('left', 4)]))]
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, action)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[2] = torch.cat((action_encoding.to(target_device), state_encoding.to(target_device)))

    round_status = network(batch[:3])
    print("Estimated rounds of destruction are {}".format(round_status[:3]))

    # Using no defense token results in destruction, so its predicted final round should be the
    # lowest.
    assert round_status[0].item() < round_status[1].item()
    assert round_status[0].item() < round_status[2].item()

def test_resolve_attack_effects_model(resolve_attack_effects_model):
    """Test basic network learning loop.

    Test lifetime predictions during the "resolve attack effects" phase.
    """
    network, errors, eval_errors = resolve_attack_effects_model
    network.eval()

    phase_name = "attack - resolve attack effects"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size

    # First verify that errors decreased during training.
    print("Eval errors for A were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    # Let's examine predictions for different dice pools, hull states, and ranges.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 4
    batch = torch.Tensor(batch_size, input_size).to(target_device)

    # Go through the following scenarios:
    # 1.1 An attack upon a ship with only 1 hull remaining
    # 1.2 The same dice pool but on a ship with full hull
    # 1.3 A dice pool with only blank dice
    # 1.4 A dice pool with only blanks when attacking at long range.

    # Create a state from resolve attack effects and an empty action.
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 4
    world_state.setPhase("ship phase", phase_name)
    ship_b.set('damage', ship_b.get('hull') - 1)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[0] = torch.cat((action_encoding.to(target_device), state_encoding.to(target_device)))

    # Same dice pool but the defender has full hull
    ship_b.set('damage', 0)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[1] = torch.cat((action_encoding.to(target_device), state_encoding.to(target_device)))

    # Full hull and all blanks
    pool_colors, pool_faces = ['black'] * 4, ['blank'] * 4
    world_state.setPhase("ship phase", phase_name)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    # Re-encode the action since the world state has changed.
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[2] = torch.cat((action_encoding.to(target_device), state_encoding.to(target_device)))

    # Full hull, all blanks, firing at long range
    pool_colors, pool_faces = ['red'] * 2, ['blank'] * 2
    world_state.setPhase("ship phase", phase_name)
    attack = AttackState('long', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[3] = torch.cat((action_encoding.to(target_device), state_encoding.to(target_device)))

    lifetime_out = network(batch)
    print("Attack effects round estimates are {}".format(lifetime_out))

    # The lifetimes should go up sequentially with the above scenarios.
    # However, if the ship won't be destroyed the NN can't make an accurate relative number, so
    # be lenient once lifetimes go above round 6. The first scenario should result in
    # destruction however.
    assert lifetime_out[0].item() < 6
    for i in range(batch.size(0) - 1):
        assert (lifetime_out[i].item() < lifetime_out[i + 1].item()
                or (lifetime_out[i].item() > 6. and lifetime_out[i + 1].item() > 6.))

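# Both model tests above assemble each batch row with the same torch.cat pattern. A helper like
# this sketch (a possible refactor, not part of the existing code) captures that pattern:
def encode_batch_row(world_state, action, device):
    """Concatenate the world, action, and attack-state encodings into one network input row."""
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, action)))
    state_encoding = Encodings.encodeAttackState(world_state)
    return torch.cat((action_encoding.to(device), state_encoding.to(device)))

# With it, a scenario row such as batch[0] above would become:
#   batch[0] = encode_batch_row(world_state, None, target_device)
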
def spendDefenseTokens(self, world_state):
    """Choose which defense tokens to spend during the "spend defense tokens" sub-phase.

    Args:
        world_state (WorldState) : Contains the list of ships and the dice pool.
    Returns:
        list[(str, varies)] : A list of actions. Each action is a tuple of the token name to
                              spend and the target(s) of that token. An empty list means that
                              no token will be spent. If a defense token has no target then the
                              second value of the tuple is None. For evade tokens the second
                              value is the index of the die to target. For redirect tokens the
                              second value is a tuple of (str, int) for the target hull zone
                              and the amount of damage to direct to that hull zone.
    """
    # TODO FIXME The return type is totally messed up, a list of actions should be allowable
    # from here.
    # We only handle one sub-phase in this function.
    assert world_state.sub_phase == "attack - spend defense tokens"

    if self.model is None:
        # Return no action
        return []

    # Encode the state, forward through the network, decode the result, and return the result.
    as_enc = Encodings.encodeAttackState(world_state)
    as_enc = as_enc.to(self.device)
    if not self.model.with_novelty:
        # Take a random action with probability randprob, otherwise forward through the policy
        # network.
        if self.randprob >= random.random():
            # Take a random action
            action = self.random_agent("def_tokens", as_enc)[0]
        else:
            action = self.model.forward("def_tokens", as_enc)[0]
        # Remember this state action pair if in memory mode
        if self.remembering:
            self.memory.append((world_state.attack, as_enc, action))
    else:
        action, novelty = self.model.forward("def_tokens", as_enc)
        # Remove the batch dimension
        action = action[0]
        novelty = novelty[0]
        # Remember this state action pair if in memory mode
        if self.remembering:
            self.memory.append((world_state.attack, as_enc, action, novelty))

    # Don't return the lifetime prediction (used in train_defense_tokens.py)
    #with torch.no_grad():
    #    action = torch.round(action[:Encodings.calculateSpendDefenseTokensSize()])

    # Determine which tokens should be spent. Spend any whose output exceeds log(0.5), i.e. a
    # probability above one half.
    with torch.no_grad():
        spend_green = action[:len(ArmadaTypes.defense_tokens)] > math.log(0.5)
        spend_red = action[len(ArmadaTypes.defense_tokens):
                           2 * len(ArmadaTypes.defense_tokens)] > math.log(0.5)
        spent_tokens = spend_green + spend_red

    # Return now if no token will be spent
    if (0 == spent_tokens).all():
        return []

    # Handle the tokens.
    # First check for validity. If the selected token isn't valid then use no token.
    # TODO Perhaps this should be penalized in some way
    green_idx, green_len = Ship.get_index('defense_tokens_green')
    red_idx, red_len = Ship.get_index('defense_tokens_red')
    defender_green_tokens = world_state.attack.defender.encoding[green_idx:green_idx + green_len]
    defender_red_tokens = world_state.attack.defender.encoding[red_idx:red_idx + red_len]
    if (spend_green > defender_green_tokens).any() or (spend_red > defender_red_tokens).any():
        return []

    # Verify that these tokens have not been the target of an accuracy and that they can be
    # spent.
    for idx in range(len(ArmadaTypes.defense_tokens)):
        if 0 < world_state.attack.accuracy_tokens[idx] and 0 < spent_tokens[idx]:
            return []

    # TODO FIXME Those last two checks (for token availability and non-accuracy status) should
    # be enforced via a fixed input to the network that suppresses token outputs if they are
    # not available. This would make learning simpler.

    actions = []
    # Handle each token, decoding the returned action based upon the token type.
    evade_index = ArmadaTypes.defense_tokens.index("evade")
    if 0 < defender_green_tokens[evade_index].item() + defender_red_tokens[evade_index].item():
        begin = Encodings.getSpendDefenseTokensEvadeOffset()
        end = begin + Encodings.max_die_slots
        # Get the index of the die with the maximum response
        _, die_idx = action[begin:end].max(0)
        # Check for an invalid response from the agent
        # TODO Perhaps this should be penalized in some way
        # TODO FIXME This should also be suppressed through die availability.
        # TODO FIXME Also handle extreme range with 2 die targets
        # NOTE die_slots (a mapping from encoded die slots back to dice pool indices) is
        # assumed to be defined elsewhere in this module.
        if len(die_slots) <= die_idx.item():
            pass
        else:
            color = (ArmadaTypes.green if 0 < defender_green_tokens[evade_index].item()
                     else ArmadaTypes.red)
            src_die_slot = die_slots[die_idx.item()]
            actions.append(("evade", (src_die_slot)))

    # TODO This only supports redirecting to a single hull zone currently
    redir_index = ArmadaTypes.defense_tokens.index("redirect")
    if 0 < spent_tokens[redir_index]:
        begin = Encodings.getSpendDefenseTokensRedirectOffset()
        end = begin + len(ArmadaTypes.hull_zones)
        # The encoding has a value for each hull zone. We should check if an upgrade allows the
        # defender to redirect to nonadjacent or multiple hull zones, but for now we will just
        # handle the base case. TODO
        adj_hulls = world_state.attack.defender.adjacent_zones(
            world_state.attack.defending_hull)

        # Redirect to whichever adjacent hull zone was assigned the greater amount. TODO
        redir_hull = None
        redir_amount = 0
        for hull in adj_hulls:
            hull_redir_amount = round(action[begin + ArmadaTypes.hull_zones.index(hull)].item())
            if hull_redir_amount > redir_amount:
                redir_hull = hull
                redir_amount = hull_redir_amount

        # Make sure there is actual redirection
        if redir_hull is not None:
            color = (ArmadaTypes.green if 0 < defender_green_tokens[redir_index].item()
                     else ArmadaTypes.red)
            actions.append(("redirect", (redir_hull, redir_amount)))

    # Other defense tokens have no targets
    for tindx, token_type in enumerate(ArmadaTypes.defense_tokens):
        if token_type not in ["evade", "redirect"]:
            if 0 < spent_tokens[tindx]:
                color = (ArmadaTypes.green if 0 < defender_green_tokens[tindx].item()
                         else ArmadaTypes.red)
                actions.append((token_type, (color, None)))

    return actions

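# For reference, return values from spendDefenseTokens take shapes like the following. The
# example values are illustrative only; the formats come from the docstring and branches above
# (the TODO FIXME notes that the evade/redirect tuples are not yet consistent with the
# (color, target) form used elsewhere):
#
#   []                                       # spend no token
#   [("brace", (ArmadaTypes.green, None))]   # spend a green brace, which has no target
#   [("redirect", ("left", 2))]              # redirect 2 damage to the left hull zone
#   [("evade", (3))]                         # target the die in pool slot 3 with an evade
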
def collect_attack_batches(batch_size, attacks, subphase):
    """A generator that collects training batches from a list of attack logs.

    Collect all of the actions taken during the resolve attack effects stage of a trial and
    associate them with the round number when the trial ended. Only sample a single
    action-state pair from each trial into the training batch though. This avoids the network
    simply memorizing the output of specific scenarios in the event of fairly unique events
    (for example a certain unlikely dice roll or combination of shields and hull in the
    defending ship).

    Args:
        batch_size (int)            : Maximum batch size.
        attacks (List[List[tuple]]) : Each sublist is all of the states and actions from a
                                      sequence.
        subphase (str)              : Name of the subphase where state/action pairs are
                                      collected.
    Yields:
        (batch, labels) : Tuple of the batch and labels.
    """
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(subphase)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    batch = torch.zeros(batch_size, input_size)
    labels = torch.zeros(batch_size, 1)

    # Variables for collection
    # collect_state records what we are doing inside of the sample loop
    collect_state = 0
    # The state most recently seen in the desired subphase
    last_state = None
    cur_sample = 0
    last_round = 0
    for sequence in attacks:
        # We only collect a single sample per sequence. Collect all of the state action pairs
        # of interest first and then select a single one to use.
        state_action_pairs = []
        for attack in sequence:
            if 'state' == attack[0]:
                last_round = attack[1].round
                if attack[1].sub_phase == subphase:
                    last_state = attack[1]
                    collect_state = 1
                else:
                    # Not in the desired subphase and the attack trial is not complete.
                    # Waiting for the next attack in the trial or for the trial to end.
                    collect_state = 2
            elif 'action' == attack[0] and 1 == collect_state:
                # Collect the action associated with last_state.
                state_action_pairs.append((last_state, attack[1]))

        # Collect a single sample (as long as one was present)
        if 0 < len(state_action_pairs):
            selected = random.choice(state_action_pairs)
            # The training label is the final round where the ship was destroyed, or 7 if the
            # ship was not destroyed in the 6 game rounds.
            labels[cur_sample] = last_round
            Encodings.encodeWorldState(world_state=selected[0],
                                       encoding=batch[cur_sample, :world_size])
            Encodings.encodeAction(subphase=selected[0].sub_phase, action_list=selected[1],
                                   encoding=batch[cur_sample,
                                                  world_size:world_size + action_size])
            Encodings.encodeAttackState(world_state=selected[0],
                                        encoding=batch[cur_sample, world_size + action_size:])
            cur_sample += 1

        # When a full batch is collected yield it immediately.
        if cur_sample == batch_size:
            yield (batch[:cur_sample], labels[:cur_sample])
            # Make new buffers to store training data
            batch = torch.zeros(batch_size, input_size)
            labels = torch.zeros(batch_size, 1)
            cur_sample = 0

    # If there are leftover samples that did not fill a batch still yield them.
    if 0 < cur_sample:
        yield (batch[:cur_sample], labels[:cur_sample])

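# A minimal consumption sketch for collect_attack_batches. The optimizer, loss choice, and
# function name here are assumptions for illustration; only the (batch, labels) generator
# contract and network(batch) -> per-sample round predictions come from the code above.
# (Device placement is omitted for brevity.)
def sketch_train_epoch(network, optimizer, attacks, subphase):
    """Fit the network's round-of-destruction predictions over one pass of the attack logs."""
    criterion = torch.nn.MSELoss()
    network.train()
    for batch, labels in collect_attack_batches(batch_size=32, attacks=attacks,
                                                subphase=subphase):
        optimizer.zero_grad()
        loss = criterion(network(batch), labels)
        loss.backward()
        optimizer.step()
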