def test_resolve_attack_effects_model(resolve_attack_effects_model):
    """Test basic network learning loop.

    Test lifetime predictions during the resolve attack effects phase.

    Args:
        resolve_attack_effects_model: Fixture yielding a (network, training
            errors, eval errors) tuple for the resolve attack effects phase.
    """
    network, errors, eval_errors = resolve_attack_effects_model
    network.eval()
    phase_name = "attack - resolve attack effects"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size

    # First verify that errors decreased during training.
    print("Eval errors for A were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    # Let's examine predictions for different ranges and hull zones.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 4
    # Zero-fill rather than torch.Tensor(...), which leaves uninitialized memory.
    batch = torch.zeros(batch_size, input_size).to(target_device)

    def encode_scenario(world_state):
        """Encode the current world and attack state with an empty action."""
        action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                     Encodings.encodeAction(world_state.sub_phase, None)))
        state_encoding = Encodings.encodeAttackState(world_state)
        return torch.cat(
            (action_encoding.to(target_device), state_encoding.to(target_device)))

    # Let's examine predictions for different dice pools and spent defense tokens.
    # Go through the following scenarios:
    # 1.1 An attack upon a ship with only 1 hull remaining
    # 1.2 The same dice pool but on a ship with full hull
    # 1.3 A dice pool with only blank dice
    # 1.4 A dice pool with only blanks when attacking at long range.

    # Create a state from resolve attack effects and an empty action.
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"],
                  upgrades=[], player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"],
                  upgrades=[], player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 4
    world_state.setPhase("ship phase", "attack - resolve attack effects")

    # 1.1: the defender has only 1 hull remaining.
    ship_b.set('damage', ship_b.get('hull') - 1)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    batch[0] = encode_scenario(world_state)

    # 1.2: same dice pool but the defender has full hull.
    ship_b.set('damage', 0)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    batch[1] = encode_scenario(world_state)

    # 1.3: full hull and all blanks.
    # BUGFIX: the original reused the action encoding from scenario 1.2 here and in
    # 1.4; re-encode so the input reflects the current world state.
    pool_colors, pool_faces = ['black'] * 4, ['blank'] * 4
    world_state.setPhase("ship phase", "attack - resolve attack effects")
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    batch[2] = encode_scenario(world_state)

    # 1.4: full hull, all blanks, firing at red (long) range.
    pool_colors, pool_faces = ['red'] * 2, ['blank'] * 2
    world_state.setPhase("ship phase", "attack - resolve attack effects")
    attack = AttackState('long', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    batch[3] = encode_scenario(world_state)

    lifetime_out = network(batch)
    print("super cool attack effects round estimates are {}".format(lifetime_out))

    # The lifetimes should go up sequentially with the above scenarios.
    # However if the ship won't be destroyed the NN can't make an accurate relative number so be
    # lenient once lifetimes go above round 6. The first scenario should result in destruction
    # however.
    assert lifetime_out[0].item() < 6
    for i in range(batch.size(0) - 1):
        assert (lifetime_out[i].item() < lifetime_out[i + 1].item()
                or (lifetime_out[i].item() > 6. and lifetime_out[i + 1].item() > 6.))
def test_defense_tokens_model(spend_defense_tokens_model):
    """Test basic network learning loop.

    Test lifetime predictions during the spend defense tokens phase.

    Args:
        spend_defense_tokens_model: Fixture yielding a (network, training errors,
            eval errors) tuple for the spend defense tokens phase.
    """
    network, errors, eval_errors = spend_defense_tokens_model
    network.eval()
    phase_name = "attack - spend defense tokens"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Only three scenarios are encoded below, so size the batch to match instead of
    # allocating 32 rows, and zero-fill rather than torch.Tensor(...), which leaves
    # uninitialized memory.
    batch_size = 3
    batch = torch.zeros(batch_size, input_size).to(target_device)

    # First verify that errors decreased during training.
    print("Eval errors for B were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    # Let's examine predictions for different dice pools and spent defense tokens.
    # Go through the following scenarios:
    # 1.1 An attack with more than enough damage to destroy the ship
    # 1.2 The same attack but a brace that would prevent destruction
    # 1.3 The same attack but a redirect that would prevent destruction
    # Result: 1.1 should have lower lifetime than 1.2 and 1.3
    # TODO(review): the originally planned scenarios 2.1 (an attack that barely
    # destroys the ship) and 2.2 (one that barely does not) were never implemented.
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"],
                  upgrades=[], player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"],
                  upgrades=[], player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 4
    world_state.setPhase("ship phase", phase_name)
    # Set the front hull zone to 2 shields.
    ship_b.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
    # Set the hull to 3 (by assigning damage to reduce the remaining hull to 3).
    ship_b.set('damage', ship_b.get('hull') - 3)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)

    # 1.1: no defense token spent.
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[0] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    # 1.2: spend a green brace token.
    action = [("brace", (ArmadaTypes.green, None))]
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, action)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[1] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    # 1.3: a fresh world state where redirecting damage to the left hull zone
    # would prevent destruction.
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"],
                  upgrades=[], player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"],
                  upgrades=[], player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 2 + ['hit'] * 2
    world_state.setPhase("ship phase", phase_name)
    # Set the front hull zone to 2 shields.
    ship_b.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
    # Set the hull to 3 (by assigning damage to reduce the remaining hull to 3).
    ship_b.set('damage', ship_b.get('hull') - 3)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    action = [("redirect", (ArmadaTypes.green, [('left', 4)]))]
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, action)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[2] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    round_status = network(batch)
    print("super cool estimated rounds of destructions are {}".format(round_status[:3]))

    # Using no defense token results in destruction, so its predicted final round
    # should be the lowest.
    assert round_status[0].item() < round_status[1].item()
    assert round_status[0].item() < round_status[2].item()
def collect_attack_batches(batch_size, attacks, subphase):
    """A generator to collect training batches from a list of attack logs.

    Collect all of the actions taken during the resolve attack effects stage of a trial and
    associate them with the round number when the trial ended. Only sample a single action-state
    pair from each trial into the training batch though. This avoids the network simply memorizing
    the output of specific scenarios in the event of fairly unique events (for example a certain
    unlikely dice roll or combination of shields and hull in the defending ship).

    Args:
        batch_size (int): Maximum batch size.
        attacks (List[List[tuple]]): Each sublist is all of the states and actions from a
            sequence.
        subphase (str): Name of the subphase where state/action pairs are collected.
    Yields:
        (batch, labels): Tuple of the batch and label tensors, each holding at most
            batch_size rows. Fresh tensors are created for each yielded batch so this
            function is safe to use as a dataset iterator for a multithreaded dataloader.
    """
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(subphase)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    batch = torch.zeros(batch_size, input_size)
    labels = torch.zeros(batch_size, 1)

    # collect_state records what we are doing inside of the sample loop:
    # 0 = nothing seen yet, 1 = collecting actions for last_state,
    # 2 = in a different subphase, waiting for the next matching state.
    collect_state = 0
    # The most recent state seen in the desired subphase.
    last_state = None
    cur_sample = 0
    # The round of the most recent state; used as the training label.
    last_round = 0
    for sequence in attacks:
        # We only collect a single sample per sequence. Collect all of the state action pairs of
        # interest first and then select a single one to use.
        state_action_pairs = []
        for attack in sequence:
            if 'state' == attack[0]:
                last_round = attack[1].round
                if attack[1].sub_phase == subphase:
                    last_state = attack[1]
                    collect_state = 1
                else:
                    # Not in the desired subphase and the attack trial is not complete.
                    # Waiting for the next attack in the trial or for the trial to end.
                    collect_state = 2
            elif 'action' == attack[0] and 1 == collect_state:
                # Collect the actions associated with last_state.
                state_action_pairs.append((last_state, attack[1]))
        # Collect a single sample (as long as one was present).
        if state_action_pairs:
            selected = random.choice(state_action_pairs)
            # The training label is the final round of the sequence (presumably the round the
            # ship was destroyed, or a sentinel round if it survived -- determined by the
            # logging code that produced `attacks`).
            labels[cur_sample] = last_round
            # selected[0].sub_phase == subphase by construction, so the sizes computed above
            # still apply; no need to recompute them per sample.
            Encodings.encodeWorldState(world_state=selected[0],
                                       encoding=batch[cur_sample, :world_size])
            Encodings.encodeAction(subphase=selected[0].sub_phase, action_list=selected[1],
                                   encoding=batch[cur_sample, world_size:world_size + action_size])
            Encodings.encodeAttackState(world_state=selected[0],
                                        encoding=batch[cur_sample, world_size + action_size:])
            cur_sample += 1
            # When a full batch is collected yield it immediately.
            if cur_sample == batch_size:
                yield (batch[:cur_sample], labels[:cur_sample])
                # Make new buffers so the yielded tensors are not overwritten.
                batch = torch.zeros(batch_size, input_size)
                labels = torch.zeros(batch_size, 1)
                cur_sample = 0
    # If there are leftover samples that did not fill a batch still yield them.
    if 0 < cur_sample:
        yield (batch[:cur_sample], labels[:cur_sample])