def test_die_encodings():
    """Check that the attack-state dice encoding matches the attacker's front armament."""
    random_agent = RandomAgent()
    attacker_template = ship_templates["Attacker"]
    attacker = ship.Ship(name="Attacker", template=attacker_template, upgrades=[],
                         player_number=1)

    # Sum the front-arc dice listed in the template, skipping empty entries.
    expected_dice = 0
    for color in ('Red', 'Blue', 'Black'):
        armament = attacker_template[f"Armament Front {color}"]
        if len(armament) > 0:
            expected_dice += int(armament)

    # The ship's dice range holds one count per die color for every hull zone.
    color_count = len(ArmadaDice.die_colors)
    front_start = ArmadaTypes.hull_zones.index('front') * color_count
    front_stop = front_start + color_count
    encoded_front_dice = int(attacker.get_range('dice')[front_start:front_stop].sum())
    # The ship encoding should agree with the template's dice count.
    assert encoded_front_dice == expected_dice

    defender = ship.Ship(name="Defender", template=ship_templates["All Defense Tokens"],
                         upgrades=[], player_number=2)
    _, world_state = make_encoding(attacker, defender, "short", random_agent)
    attack_encoding = Encodings.encodeAttackState(world_state)

    # The rolled pool should contain exactly as many dice as the front hull zone has.
    pool_start = Encodings.getAttackDiceOffset()
    pool_encoding = attack_encoding[pool_start:pool_start + Encodings.dieEncodingSize()]
    assert int(pool_encoding.sum().item()) == encoded_front_dice
def __init__(self, with_novelty=True):
    """Build the defense-token policy model and, optionally, the novelty-estimation models.

    Args:
        with_novelty (bool): When True, also create the random-distillation models
            "def_tokens_novelty" (trained) and "def_tokens_static" (never given an optimizer).
    """
    # An extra two outputs
    # are added in for the mean and variance estimation.
    nn.Module.__init__(self)
    self.with_novelty = with_novelty
    def_in = Encodings.calculateAttackSize()
    def_out = Encodings.calculateSpendDefenseTokensSize() + 2
    self.models = nn.ModuleDict()
    self.models["def_tokens"] = self.init_fc_params(def_in, def_out)
    self.optimizers = {
        "def_tokens": torch.optim.Adam(self.models["def_tokens"].parameters())
    }
    if self.with_novelty:
        # We will also create two more models to use for random distillation. The first random
        # model will remain static and the second will learn to predict the outputs of the
        # first. The difference between the two outputs will be used to estimate the novelty of
        # the current state. If the state is new then the second model will not be able to make
        # a good prediction of the first model's outputs. In other words, the first model
        # projects the inputs into a new latent space. The ability of the second model to
        # predict the projection into the latent space should be correlated to how similar this
        # state is to ones we have previously visited.
        # The novelty network is a clone of the corresponding network
        # NOTE(review): this ModuleList holds a freshly initialized network plus the *same*
        # module object stored in self.models["def_tokens"] (not a copy), so the
        # "def_tokens_novelty" optimizer below also updates the policy network's parameters.
        # Confirm that this sharing is intended.
        self.models["def_tokens_novelty"] = nn.ModuleList([
            self.init_fc_params(def_in, def_out),
            self.models["def_tokens"]
        ])
        self.models["def_tokens_static"] = self.init_fc_params(
            def_in, def_out)
        self.optimizers["def_tokens_novelty"] = torch.optim.Adam(
            self.models["def_tokens_novelty"].parameters())
    # NOTE(review): nn.Softmax() without `dim` relies on PyTorch's deprecated implicit
    # dimension choice; confirm which dimension callers expect.
    self.sm = nn.Softmax()
def test_accuracy_encodings():
    """Test that the encoding is correct for defense tokens targeted by an accuracy."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"], upgrades=[],
                         player_number=1)
    # The "Triple Brace" template presumably provides three green brace tokens.
    three_brace = ship.Ship(name="Double Brace", template=ship_templates["Triple Brace"],
                            upgrades=[], player_number=2)
    enc_three_brace, world_state = make_encoding(attacker, three_brace, "short", agent)

    # Accuracy flags are laid out as one slot per token type for green tokens, followed by one
    # slot per token type for red tokens.
    num_token_types = len(ArmadaTypes.defense_tokens)
    brace = ArmadaTypes.defense_tokens.index('brace')
    green_acc_begin = Encodings.getAttackTokenOffset()
    green_acc_end = green_acc_begin + num_token_types
    red_acc_begin = green_acc_end
    red_acc_end = red_acc_begin + num_token_types

    # Verify that no tokens (green or red) are targeted at first
    assert 0.0 == enc_three_brace[green_acc_begin:red_acc_end].sum()

    # Now make exactly one brace token red and target it. Only one brace is spent so that two
    # green braces remain for the two green accuracies below.
    three_brace.spend_token('brace', ArmadaTypes.green)
    world_state.attack.accuracy_defender_token(brace, ArmadaTypes.red)
    encoding = Encodings.encodeAttackState(world_state)
    # Verify that only the red token has the accuracy flag set
    assert encoding[red_acc_begin + brace].item() == 1.
    assert encoding[red_acc_begin:red_acc_end].sum().item() == 1.
    assert encoding[green_acc_begin:green_acc_end].sum().item() == 0.

    # Target both remaining green tokens
    world_state.attack.accuracy_defender_token(brace, ArmadaTypes.green)
    world_state.attack.accuracy_defender_token(brace, ArmadaTypes.green)
    encoding = Encodings.encodeAttackState(world_state)
    # Verify that two green and one red brace have the accuracy flag
    assert encoding[red_acc_begin + brace].item() == 1.
    assert encoding[red_acc_begin:red_acc_end].sum().item() == 1.
    assert encoding[green_acc_begin + brace].item() == 2.
    assert encoding[green_acc_begin:green_acc_end].sum().item() == 2.
def test_spent_encodings():
    """Check the spent-token section of the attack encoding for every defense token type."""
    agent = LearningAgent()

    def build_state():
        # Fresh ships each time so that earlier spends do not leak into later checks.
        attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"],
                             upgrades=[], player_number=1)
        defender = ship.Ship(name="Defender", template=ship_templates["All Defense Tokens"],
                             upgrades=[], player_number=2)
        return make_encoding(attacker, defender, "short", agent)

    encoding, world_state = build_state()
    token_types = ArmadaTypes.defense_tokens
    # Layout: defender and attacker ship encodings, then the green and red accuracy flags,
    # then the spent flags.
    spent_begin = 2 * ship.Ship.encodeSize() + 2 * len(token_types)
    spent_end = spent_begin + len(token_types)

    # Nothing is spent by default.
    assert torch.sum(encoding[spent_begin:spent_end]) == 0.

    # Spending every token should flag every slot.
    for token_type in token_types:
        world_state.attack.defender_spend_token(token_type, 'green')
    encoding = Encodings.encodeAttackState(world_state)
    assert torch.sum(encoding[spent_begin:spent_end]).item() == len(token_types)

    # Spending a single token should flag exactly its own slot.
    for expected_slot, token_type in enumerate(token_types):
        _, world_state = build_state()
        world_state.attack.defender_spend_token(token_type, 'green')
        spent = Encodings.encodeAttackState(world_state)[spent_begin:spent_end]
        assert torch.sum(spent).item() == 1.0
        assert spent[expected_slot].item() == 1.0
def make_encoding(ship_a, ship_b, attack_range, agent):
    """Set up ship_a attacking ship_b and return the encoded attack state.

    Both ships are added to a fresh WorldState (ship_a at index 0, ship_b at index 1). ship_a
    rolls its "front" hull zone dice at the given range, and the resulting front-to-front
    attack is stored in the world state and encoded.

    Args:
        ship_a (Ship): Attacking ship; rolls from its "front" hull zone.
        ship_b (Ship): Defending ship; its "front" hull zone is the target.
        attack_range (str): Attack range used for the roll and the AttackState.
        agent: Unused; kept so existing callers continue to work.
    Returns:
        (torch.Tensor, WorldState): The encoding from Encodings.encodeAttackState and the world
        state holding the attack.
    """
    world_state = WorldState()
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)

    pool_colors, pool_faces = ship_a.roll("front", attack_range)
    attack = AttackState(attack_range=attack_range,
                         attacker=ship_a,
                         attacking_hull="front",
                         defender=ship_b,
                         defending_hull="front",
                         pool_colors=pool_colors,
                         pool_faces=pool_faces)
    world_state.updateAttack(attack)

    encoding = Encodings.encodeAttackState(world_state)
    return encoding, world_state
def __init__(self):
    """Create one fully connected model and Adam optimizer per supported phase.

    Every model maps the world state encoding concatenated with the attack state encoding to
    the action encoding of its phase.
    """
    input_size = Encodings.calculateWorldStateSize() + Encodings.calculateAttackSize()
    nn.Module.__init__(self)
    self.models = nn.ModuleDict()
    self.optimizers = {}
    # TODO FIXME Loop through all of the possible phases instead of just these two
    handled_phases = ("attack - resolve attack effects", "attack - spend defense tokens")
    for phase in handled_phases:
        phase_model = self.init_fc_params(input_size, Encodings.calculateActionSize(phase))
        self.models[phase] = phase_model
        self.optimizers[phase] = torch.optim.Adam(phase_model.parameters())
    self.sm = nn.Softmax()
def test_roll_encodings():
    """Test that the encoding is correct for dice pools and faces."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"], upgrades=[],
                         player_number=1)
    no_token = ship.Ship(name="No Defense Tokens", template=ship_templates["No Defense Tokens"],
                         upgrades=[], player_number=2)
    dice_begin = Encodings.getAttackDiceOffset()
    dice_end = dice_begin + Encodings.dieEncodingSize()

    _, world_state = make_encoding(attacker, no_token, "short", agent)
    # Do 100 trials to ensure everything is working as expected
    for _ in range(100):
        # Make a random roll and place it into the existing attack state
        pool_colors, pool_faces = attacker.roll("front", "short")
        attack = world_state.attack
        attack.pool_faces = pool_faces
        attack.pool_colors = pool_colors
        world_state.updateAttack(attack)
        enc_attack = Encodings.encodeAttackState(world_state)

        # Restrict to the dice section so that fields after it cannot affect the final check.
        enc_dice = enc_attack[dice_begin:dice_end]
        # Every rolled die must be encoded; remove each one from the encoding as it is matched.
        for face, color in zip(attack.pool_faces, attack.pool_colors):
            die_slot = Encodings.dieOffset(color=color, face=face)
            assert 0. < enc_dice[die_slot].item()
            enc_dice[die_slot] -= 1
        # All dice from the pool should have been matched and there should be no more encoded
        assert torch.sum(enc_dice).item() <= 0.
def test_range_encodings():
    """Each attack range should set exactly its own one-hot slot in the attack encoding."""
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"], upgrades=[],
                         player_number=1)
    no_token = ship.Ship(name="No Defense Tokens", template=ship_templates["No Defense Tokens"],
                         upgrades=[], player_number=2)

    range_start = Encodings.getAttackRangeOffset()
    range_stop = range_start + len(ArmadaTypes.ranges)
    for expected_slot, attack_range in enumerate(ArmadaTypes.ranges):
        enc_attack, _ = make_encoding(attacker, no_token, attack_range, agent)
        assert torch.sum(enc_attack[range_start:range_stop]) == 1
        assert enc_attack[range_start + expected_slot].item() == 1.0
def __init__(self, model=None):
    """Initialize the learning agent with its supported state handlers.

    Args:
        model (torch.nn.Module or None): Policy model. If None this agent will pass for all
            supported states; otherwise it is moved to self.device and put in eval mode.
    """
    handler = {
        "ship phase - attack - resolve attack effects": self.resolveAttackEffects,
        "ship phase - attack - spend defense tokens": self.spendDefenseTokens,
    }
    super(LearningAgent, self).__init__(handler)
    # Run on the GPU when one is available.
    self.device = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu")
    self.model = model
    if self.model is not None:
        self.model = self.model.to(self.device)
        self.model.eval()
    self.attack_enc_size = Encodings.calculateAttackSize()
    # State/action tuples recorded by the handlers while self.remembering is True.
    self.memory = []
    self.remembering = False
    # NOTE(review): spendDefenseTokens reads self.randprob and self.random_agent, which are not
    # initialized here; confirm they are assigned elsewhere before that handler runs.
def _print_lifetime_debug(batch, labels, errors, i):
    """Print a decoded description of row ``i`` of a lifetime batch (debug aid)."""
    # This is messy debugging code, but sometimes a test may fail after bug introductions so it
    # is helpful to leave this in to speed up debugging.
    world_size = Encodings.calculateWorldStateSize()
    world_encoding = batch[i, :world_size]
    # World encoding slots 0, 1 and 2 hold the round, main phase index and sub phase index.
    phase_name = ArmadaPhases.main_phases[int(world_encoding[1].item())]
    sub_phase_name = ArmadaPhases.sub_phases[phase_name][int(world_encoding[2].item())]
    print(f"Error {i} is {errors[i]}")
    print(f"\tRound {world_encoding[0].item()}")
    print(f"\tSubphase {sub_phase_name}")
    # Each row is laid out as [world state | action | attack state].
    action_size = Encodings.calculateActionSize(sub_phase_name)
    action_encoding = batch[i, world_size:world_size + action_size]
    attack_state_encoding = batch[i, world_size + action_size:]
    if "attack - resolve attack effects" == sub_phase_name:
        print(f"\tattack effect encoding is {action_encoding}")
    elif "attack - spend defense tokens" == sub_phase_name:
        print(f"\tspend defense token encoding is {action_encoding}")
    else:
        print("Cannot print information about {}".format(sub_phase_name))
    ship_size = Ship.encodeSize()
    defender = Ship(name="Defender", player_number=1,
                    encoding=attack_state_encoding[:ship_size])
    attacker = Ship(name="Attacker", player_number=1,
                    encoding=attack_state_encoding[ship_size:2 * ship_size])
    print("\tAttacker is {}".format(attacker))
    print("\tDefender is {}".format(defender))
    # TODO FIXME Enough dice in a pool seems to end a ship, but unless the pools are
    # incredibly large this doesn't seem to be happening. Damage does not seem to be
    # accumlating between rounds.
    die_offset = Encodings.getAttackDiceOffset()
    dice_encoding = attack_state_encoding[die_offset:die_offset + Encodings.dieEncodingSize()]
    print("\tDice are {}".format(dice_encoding))
    print(f"\tLabel is {labels[i]}")


def update_lifetime_network(lifenet, batch, labels, optimizer, eval_only=False):
    """Do a forward and (optionally) backward pass through the given lifetime network.

    Args:
        lifenet (torch.nn.Module): Trainable torch model
        batch (torch.tensor)     : Training batch
        labels (torch.tensor)    : Training labels
        optimizer (torch.optim.Optimizer): Optimizer for lifenet parameters. Not used (may be
            None) when eval_only is True.
        eval_only (bool)         : Only evaluate, don't update parameters. The network is put
            in eval mode for the pass and returned to train mode afterwards, even if the pass
            raises.
    Returns:
        float: Average absolute error for this batch.
    """
    if eval_only:
        lifenet.eval()
    try:
        if eval_only:
            # No parameter update follows, so there is no need to build an autograd graph.
            with torch.no_grad():
                prediction = lifenet.forward(batch)
        else:
            prediction = lifenet.forward(batch)

        with torch.no_grad():
            errors = (prediction - labels).abs()
            error = errors.mean().item()
            for i in range(errors.size(0)):
                # Debug on crazy errors or nan values.
                if errors[i] > 1000 or torch.isnan(errors[i]):
                    _print_lifetime_debug(batch, labels, errors, i)

        if not eval_only:
            # Absolute (L1) error is the training loss. Alternatives that were tried: a Normal
            # negative log likelihood, a Poisson negative log likelihood (noted to work well,
            # with a ReLU + small epsilon to keep the rate positive), and plain MSE.
            loss_fn = torch.nn.L1Loss()
            loss = loss_fn(prediction, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    finally:
        if eval_only:
            lifenet.train()
    return error
def test_policy_learning(spend_defense_tokens_model, resolve_attack_effects_model):
    """Train a model to produce a better than random choice policy for each attack subphase.

    The trained lifetime prediction models (fixtures) score the learning agent's actions and
    are not updated themselves, unlike full reinforcement learning, so this only tests the
    mechanism. The test asserts that the policy loss decreased for every subphase.
    """
    def_tokens_model = spend_defense_tokens_model[0]
    res_attack_model = resolve_attack_effects_model[0]
    prediction_models = {
        "attack - spend defense tokens": def_tokens_model.eval(),
        "attack - resolve attack effects": res_attack_model.eval(),
    }
    # Do the training. Use the prediction model lifetime to create the loss target. The loss
    # will be the difference between the max possible round and the predicted round.
    # For the defense token model the higher round is better, for the attack effect model a
    # lower round is better.
    loss_fn = {
        "attack - spend defense tokens": lambda predictions: 7.0 - predictions,
        "attack - resolve attack effects": lambda predictions: predictions - 1.,
    }
    subphases = ["attack - spend defense tokens", "attack - resolve attack effects"]

    # Generate a new learning model
    learning_agent = LearningAgent(SeparatePhaseModel())
    # TODO FIXME The learning agent doesn't really have a training mode
    learning_agent.model.train()
    # Use the device the agent placed its model on (CUDA only when available) rather than
    # assuming a GPU; the prediction fixtures pick their device the same way.
    device = learning_agent.device
    random_agent = RandomAgent()

    training_ships = ["All Defense Tokens", "All Defense Tokens",
                      "Imperial II-class Star Destroyer", "MC80 Command Cruiser",
                      "Assault Frigate Mark II A", "No Shield Ship", "One Shield Ship",
                      "Mega Die Ship"]
    attackers = [Ship(name=name, template=ship_templates[name], upgrades=[], player_number=1,
                      device='cpu') for name in training_ships]
    defenders = [Ship(name=name, template=ship_templates[name], upgrades=[], player_number=2,
                      device='cpu') for name in training_ships]

    batch_size = 32
    # Remember the training loss values to test for improvement
    losses = {subphase: [] for subphase in subphases}
    # This gets samples to use for training. We will use a random agent to generate the states.
    # In reinforcement learning the agent would alternate between random actions and actions
    # from the learning agent to balance exploration of the state space with exploitation of
    # learned behavior.
    samples = get_n_examples(
        n_examples=1000, ship_a=attackers, ship_b=defenders, agent=random_agent)

    world_size = Encodings.calculateWorldStateSize()
    # Get a batch for each subphase
    for subphase in subphases:
        # The learning_agent will take in the world state and attack state and will produce a
        # new action encoding. Then the prediction network will take in a new tuple with this
        # action and predict a final round. The difference between the random action and the
        # final round with the network's action is the reward.
        action_size = Encodings.calculateActionSize(subphase)
        for batch, _ in collect_attack_batches(batch_size=batch_size, attacks=samples,
                                               subphase=subphase):
            # Rows are laid out as [world state | action | attack state].
            batch = batch.to(device)
            world_part = batch[:, :world_size]
            attack_part = batch[:, world_size + action_size:]
            # The learning agent takes in the world state along with the attack state.
            new_batch = torch.cat((world_part, attack_part), dim=1)
            # TODO FIXME Just make a forward function that takes in a phase name
            new_actions = learning_agent.model.models[subphase](new_batch)
            new_action_state = torch.cat((world_part, new_actions, attack_part), dim=1)
            action_result = prediction_models[subphase](new_action_state)
            loss = loss_fn[subphase](action_result)
            learning_agent.model.get_optimizer(subphase).zero_grad()
            loss.sum().backward()
            learning_agent.model.get_optimizer(subphase).step()
            with torch.no_grad():
                losses[subphase].append(loss.mean().item())
    # In reinforcement learning there would also be a phase where the prediction models are
    # updated.

    for subphase in subphases:
        print(f"losses for {subphase} start with {losses[subphase][0:5]} and end with {losses[subphase][-5:]}")
        assert losses[subphase][-1] < losses[subphase][0]

    # TODO FIXME HERE See if the policy networks produce better results than random actions
    learning_agent.model.models.eval()
def spend_defense_tokens_model(create_spend_defense_tokens_dataset):
    """Train a basic lifetime prediction model for the 'spend defense tokens' subphase.

    The network predicts defending ship lifetimes given a world state, a 'spend defense tokens'
    action, and an attack state.

    Args:
        create_spend_defense_tokens_dataset: (training batches, eval batches). Each batch holds
            the input tensor at index 0 and the label tensor at index 1.
    Returns:
        (nn.module, [float], [float]): The model, the error of every training batch, and the
        error of every eval batch (evaluated before training and after each epoch).
    """
    attacks, eval_attacks = create_spend_defense_tokens_dataset
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Lifetime network B predicts lifetimes given a state and action pair in the 'spend defense
    # tokens' subphase.
    phase_name = "attack - spend defense tokens"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    network = torch.nn.Sequential(
        torch.nn.BatchNorm1d(input_size),
        torch.nn.Linear(input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 4 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(4 * input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 1))
    network.to(target_device)
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0005)

    # Keep track of the errors for the purpose of this test
    errors = []
    eval_errors = []

    def evaluate():
        """Record the error of every eval batch without updating the network."""
        for batch in eval_attacks:
            eval_data = batch[0].to(target_device)
            eval_labels = batch[1].to(target_device)
            eval_errors.append(
                update_lifetime_network(network, eval_data, eval_labels, None, True))

    # Evaluate before training and every epoch
    evaluate()
    # Train with all of the data for 10 epochs
    for epoch in range(10):
        print("Training spend_defense_tokens_model epoch {}".format(epoch))
        train_batches = 0
        for batch in attacks:
            train_data = batch[0].to(target_device)
            train_labels = batch[1].to(target_device)
            errors.append(
                update_lifetime_network(network, train_data, train_labels, optimizer))
            train_batches += 1
        print("Finished epoch with {} batches.".format(train_batches))
        evaluate()

    return network, errors, eval_errors
def resolve_attack_effects_model(create_attack_effects_dataset):
    """Train a basic lifetime prediction model for the 'resolve attack effects' subphase.

    The network predicts defending ship lifetimes given a world state, a 'resolve attack
    effects' action, and an attack state. Training and evaluation errors are logged.

    Args:
        create_attack_effects_dataset: (training batches, eval batches). Each batch holds the
            input tensor at index 0 and the label tensor at index 1.
    Returns:
        (nn.module, [float], [float]): The model, the error of every training batch, and the
        error of every eval batch (evaluated before training and after each epoch).
    """
    attacks, eval_attacks = create_attack_effects_dataset
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Lifetime network A predicts lifetimes given a state and action pair in the 'resolve attack
    # effects' subphase.
    phase_name = "attack - resolve attack effects"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    # The network size was made large enough that training plateaued to a stable value
    # If the network is too large it has a tendency to start fitting to specific cases.
    # Batchnorm doesn't seem to help out much with this network and task; Dropout and an extra
    # BatchNorm1d(4 * input_size) after the second hidden layer were tried and left out.
    network = torch.nn.Sequential(
        torch.nn.BatchNorm1d(input_size),
        torch.nn.Linear(input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 4 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(4 * input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 1))
    network.to(target_device)
    # Higher learning rates lead to a lot of instability in the training.
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0005)

    # Keep track of the errors for the purpose of this test
    errors = []
    eval_errors = []

    def evaluate():
        """Record the error of every eval batch without updating the network."""
        for batch in eval_attacks:
            eval_data = batch[0].to(target_device)
            eval_labels = batch[1].to(target_device)
            eval_errors.append(
                update_lifetime_network(network, eval_data, eval_labels, None, True))

    # Evaluate before training and every epoch
    evaluate()
    # Train with all of the data for 10 epochs
    for epoch in range(10):
        print("Training resolve_attack_effects_model epoch {}".format(epoch))
        train_batches = 0
        for batch in attacks:
            train_data = batch[0].to(target_device)
            train_labels = batch[1].to(target_device)
            errors.append(
                update_lifetime_network(network, train_data, train_labels, optimizer))
            train_batches += 1
        print("Finished epoch with {} batches.".format(train_batches))
        evaluate()

    return network, errors, eval_errors
def spendDefenseTokens(self, world_state):
    """Choose defense tokens to spend using the policy model.

    Args:
        world_state (table) : Contains the list of ships and dice pool. Its sub_phase must be
                              "attack - spend defense tokens".
    Returns:
        list: Actions as (token name, target) tuples; an empty list means no token is spent.
        An empty list is also returned when there is no model or when the model selects a
        token that is unavailable or was targeted by an accuracy. Evade actions carry the
        index of the targeted die, redirect actions carry a (hull zone, amount) tuple, and
        all other tokens carry (color, None).
    """
    # TODO FIXME The return type is totally messed up, a list of actions should be allowable
    # from here
    # We only handle one sub-phase in this function
    assert world_state.sub_phase == "attack - spend defense tokens"

    if self.model is None:
        # Return no action
        return []

    # Encode the state, forward through the network, decode the result, and return the result.
    as_enc = Encodings.encodeAttackState(world_state)
    as_enc = as_enc.to(self.device)
    if not self.model.with_novelty:
        # With probability self.randprob take a random action, otherwise use the policy net.
        if self.randprob >= random.random():
            action = self.random_agent("def_tokens", as_enc)[0]
        else:
            action = self.model.forward("def_tokens", as_enc)[0]
        # Remember this state action pair if in memory mode
        if self.remembering:
            self.memory.append((world_state.attack, as_enc, action))
    else:
        action, novelty = self.model.forward("def_tokens", as_enc)
        # Remove the batch dimension
        action = action[0]
        novelty = novelty[0]
        # Remember this state action pair if in memory mode
        if self.remembering:
            self.memory.append(
                (world_state.attack, as_enc, action, novelty))

    # Determine which tokens should be spent. Spend any above a threshold; outputs are compared
    # against log(0.5), presumably because they are log-probabilities — confirm.
    num_tokens = len(ArmadaTypes.defense_tokens)
    threshold = math.log(0.5)
    with torch.no_grad():
        spend_green = action[:num_tokens] > threshold
        spend_red = action[num_tokens:2 * num_tokens] > threshold
        spent_tokens = spend_green + spend_red

    # Return now if no token will be spent
    if (0 == spent_tokens).all():
        return []

    # First check for validity. If the selected token isn't available then use no token.
    # TODO Perhaps this should be penalized in some way
    green_idx, green_len = Ship.get_index('defense_tokens_green')
    red_idx, red_len = Ship.get_index('defense_tokens_red')
    defender_green_tokens = world_state.attack.defender.encoding[
        green_idx:green_idx + green_len]
    defender_red_tokens = world_state.attack.defender.encoding[
        red_idx:red_idx + red_len]
    if (spend_green > defender_green_tokens).any() or (
            spend_red > defender_red_tokens).any():
        return []

    # Verify that the tokens to spend have not been the target of an accuracy
    for idx in range(num_tokens):
        if 0 < world_state.attack.accuracy_tokens[idx] and 0 < spent_tokens[idx]:
            return []

    # TODO FIXME Those last two checks (for token availability and non-accuracy status) should
    # be enforced via a fixed input to the network that suppresses token outputs if they are not
    # available. This would make learning simpler.
    actions = []

    # Evade: only when the policy chose to spend it, like every other token below.
    evade_index = ArmadaTypes.defense_tokens.index("evade")
    if 0 < spent_tokens[evade_index]:
        begin = Encodings.getSpendDefenseTokensEvadeOffset()
        end = begin + Encodings.max_die_slots
        # Get the index of the maximum die response
        _, die_idx = action[begin:end].max(0)
        # NOTE(review): die_slots was referenced here without ever being defined. It is assumed
        # to be the indices of the dice in the current attack pool — confirm.
        die_slots = list(range(len(world_state.attack.pool_faces)))
        # Ignore an invalid response from the agent (a die slot beyond the pool).
        # TODO Perhaps this should be penalized in some way
        # TODO FIXME This should also be supressed through die availability.
        # TODO FIXME Also handle extreme range with 2 die targets
        if die_idx.item() < len(die_slots):
            src_die_slot = die_slots[die_idx.item()]
            # NOTE(review): unlike the (color, None) actions below, no token color is included.
            actions.append(("evade", src_die_slot))

    # TODO This only supports redirecting to a single hull zone currently
    redir_index = ArmadaTypes.defense_tokens.index("redirect")
    if 0 < spent_tokens[redir_index]:
        begin = Encodings.getSpendDefenseTokensRedirectOffset()
        # The encoding has a value for each hull zone. We should check if an upgrade allows the
        # defender to redirect to nonadjacent or multiple hull zones, but for now we will just
        # handle the base case. TODO
        adj_hulls = world_state.attack.defender.adjacent_zones(
            world_state.attack.defending_hull)
        # Redirect to whichever hull has the greater value TODO
        redir_hull = None
        redir_amount = 0
        for hull in adj_hulls:
            hull_redir_amount = round(
                action[begin + ArmadaTypes.hull_zones.index(hull)].item())
            if hull_redir_amount > redir_amount:
                redir_hull = hull
                redir_amount = hull_redir_amount
        # Make sure there is actual redirection
        if redir_hull is not None:
            # NOTE(review): unlike the (color, None) actions below, no token color is included.
            actions.append(("redirect", (redir_hull, redir_amount)))

    # Other defense tokens with no targets; prefer the green copy when one is available.
    for tindx, token_type in enumerate(ArmadaTypes.defense_tokens):
        if token_type not in ("evade", "redirect") and 0 < spent_tokens[tindx]:
            color = ArmadaTypes.green if 0 < defender_green_tokens[
                tindx].item() else ArmadaTypes.red
            actions.append((token_type, (color, None)))

    return actions
def test_resolve_attack_effects_model(resolve_attack_effects_model):
    """Check lifetime predictions of the trained 'resolve attack effects' network.

    Training must have reduced the eval error, and predictions for four increasingly
    harmless attacks must increase (with leniency once both estimates exceed round 6).
    """
    network, _, eval_errors = resolve_attack_effects_model
    network.eval()
    phase_name = "attack - resolve attack effects"
    input_size = (Encodings.calculateWorldStateSize()
                  + Encodings.calculateActionSize(phase_name)
                  + Encodings.calculateAttackSize())

    # Training should have lowered the evaluation error.
    print("Eval errors for A were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = torch.Tensor(4, input_size).to(target_device)

    # Scenarios, one batch row each:
    #   0: four black hit+crits against a ship with only 1 hull remaining
    #   1: the same dice pool against a ship with full hull
    #   2: four black blanks at short range against full hull
    #   3: two red blanks at long range against full hull
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)

    def start_attack(attack_range, colors, faces):
        world_state.updateAttack(
            AttackState(attack_range, ship_a, 'left', ship_b, 'front', colors, faces))

    def world_and_empty_action():
        return torch.cat((Encodings.encodeWorldState(world_state),
                          Encodings.encodeAction(world_state.sub_phase, None)))

    def fill_row(row, prefix):
        attack_part = Encodings.encodeAttackState(world_state)
        batch[row] = torch.cat((prefix.to(target_device), attack_part.to(target_device)))

    hit_crit_pool = (['black'] * 4, ['hit_crit'] * 4)

    world_state.setPhase("ship phase", "attack - resolve attack effects")
    ship_b.set('damage', ship_b.get('hull') - 1)
    start_attack('short', *hit_crit_pool)
    prefix = world_and_empty_action()
    fill_row(0, prefix)

    ship_b.set('damage', 0)
    start_attack('short', *hit_crit_pool)
    prefix = world_and_empty_action()
    fill_row(1, prefix)

    # The remaining rows reuse the world/action prefix computed for row 1.
    world_state.setPhase("ship phase", "attack - resolve attack effects")
    start_attack('short', ['black'] * 4, ['blank'] * 4)
    fill_row(2, prefix)

    world_state.setPhase("ship phase", "attack - resolve attack effects")
    start_attack('long', ['red'] * 2, ['blank'] * 2)
    fill_row(3, prefix)

    lifetime_out = network(batch)
    print("super cool attack effects round estimates are {}".format(lifetime_out))
    estimates = [lifetime_out[row].item() for row in range(batch.size(0))]
    # The first scenario should end in destruction; after that, lifetimes should increase,
    # except that estimates above round 6 (ship not destroyed) are not compared strictly.
    assert estimates[0] < 6
    for earlier, later in zip(estimates, estimates[1:]):
        assert earlier < later or (earlier > 6. and later > 6.)
def _defense_tokens_world_state(phase_name, pool_colors, pool_faces):
    """Build a world state where Ship A attacks the front hull zone of Ship B at short range.

    Both ships use the "All Defense Tokens" template. Before the attack is recorded, Ship B's
    front hull zone is set to 2 shields and enough damage is assigned to leave 3 hull.

    Args:
        phase_name (str)       : Sub-phase of the "ship phase" to place the world state in.
        pool_colors (List[str]): Colors of the attack dice.
        pool_faces (List[str]) : Faces of the attack dice.
    Returns:
        WorldState: The prepared world state with the attack applied via updateAttack.
    """
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    world_state.setPhase("ship phase", phase_name)
    # Set the front hull zone to 2 shields
    ship_b.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
    # Set the hull to 3 (by assigning damage to reduce the remaining hull to 3)
    ship_b.set('damage', ship_b.get('hull') - 3)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    return world_state


def _encode_defense_tokens_row(world_state, action):
    """Concatenate the world state, action, and attack state encodings into one input row.

    Args:
        world_state (WorldState): State with an attack in progress.
        action                  : Action list for ``Encodings.encodeAction`` (None for no action).
    Returns:
        torch.Tensor: world state encoding + action encoding + attack state encoding.
    """
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, action)))
    return torch.cat((action_encoding, Encodings.encodeAttackState(world_state)))


def test_defense_tokens_model(spend_defense_tokens_model):
    """Test basic network learning loop.

    Test lifetime predictions during the spend defense tokens phase: the estimated round of
    destruction must be lower when no defense token is spent than when a brace or a redirect
    is spent.
    """
    network, _, eval_errors = spend_defense_tokens_model
    network.eval()
    phase_name = "attack - spend defense tokens"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("Eval errors for B were {}".format(eval_errors))
    # Evaluation error should have dropped over the course of training.
    assert eval_errors[0] > eval_errors[-1]

    # Examine predictions for different dice pools and spent defense tokens. In every scenario
    # the defender's front zone has 2 shields and 3 hull remaining.
    # 1. Four black hit_crit dice and no defense token; intended to destroy the ship.
    # 2. The same attack with a brace that is intended to prevent destruction.
    # 3. Two black hit_crit and two black hit dice with a redirect of 4 to the left hull zone,
    #    intended to prevent destruction.
    # Expected: scenario 1 has a lower estimated lifetime than scenarios 2 and 3. Ideally
    # scenario 1 would predict the current round, but that is not asserted.
    big_attack = _defense_tokens_world_state(phase_name, ['black'] * 4, ['hit_crit'] * 4)
    rows = [
        _encode_defense_tokens_row(big_attack, None),
        _encode_defense_tokens_row(big_attack, [("brace", (ArmadaTypes.green, None))]),
    ]
    redirect_attack = _defense_tokens_world_state(
        phase_name, ['black'] * 4, ['hit_crit'] * 2 + ['hit'] * 2)
    rows.append(_encode_defense_tokens_row(
        redirect_attack, [("redirect", (ArmadaTypes.green, [('left', 4)]))]))

    # Sized to the scenarios and zero-initialized; assigning each row also checks that the
    # encodings add up to input_size.
    batch = torch.zeros(len(rows), input_size, device=target_device)
    for index, row in enumerate(rows):
        batch[index] = row.to(target_device)

    with torch.no_grad():
        round_status = network(batch)
    print("super cool estimated rounds of destructions are {}".format(round_status))
    # Using no defense token results in destruction, the final round should be less
    assert round_status[0].item() < round_status[1].item()
    assert round_status[0].item() < round_status[2].item()
def _collect_state_action_pairs(sequence, subphase):
    """Gather the (state, action) pairs of one trial that were taken during ``subphase``.

    Entries are read positionally: ``entry[0]`` is the tag ('state' or 'action') and
    ``entry[1]`` is the world state or the action respectively. Entries with any other tag
    are ignored. All tracking state is local, so nothing leaks between trials.

    Args:
        sequence (List[tuple]): All of the states and actions from a single trial.
        subphase (str)        : Name of the subphase where state/action pairs are collected.
    Returns:
        (pairs, last_round): List of (world_state, action) tuples, and the ``round`` of the
        last 'state' entry in the sequence (0 if the sequence contains no states).
    """
    pairs = []
    last_round = 0
    # The most recent state while it is in the subphase of interest; a state from any other
    # subphase stops collection until the next matching state appears.
    current_state = None
    for entry in sequence:
        if 'state' == entry[0]:
            last_round = entry[1].round
            current_state = entry[1] if entry[1].sub_phase == subphase else None
        elif 'action' == entry[0] and current_state is not None:
            pairs.append((current_state, entry[1]))
    return pairs, last_round


def collect_attack_batches(batch_size, attacks, subphase):
    """A generator that yields training batches built from a list of attack logs.

    Collect all of the actions taken during ``subphase`` of each trial and associate them with
    the round number when the trial ended. Only a single action-state pair is sampled from each
    trial into the training batch. This avoids the network simply memorizing the output of
    specific scenarios in the event of fairly unique events (for example a certain unlikely dice
    roll or combination of shields and hull in the defending ship).

    New tensors are allocated for every batch, so a yielded batch is never overwritten by later
    samples.

    Args:
        batch_size (int)            : Maximum batch size.
        attacks (List[List[tuples]]): Each sublist is all of the states and actions from a
                                      sequence.
        subphase (str)              : Name of the subphase where state/action pairs are
                                      collected.
    Yields:
        (batch, labels): ``batch`` has shape (n, input_size) and holds the concatenated world
        state, action, and attack state encodings; ``labels`` has shape (n, 1) and holds round
        numbers. ``n`` is ``batch_size`` except possibly for the final, partial batch.
    """
    world_size = Encodings.calculateWorldStateSize()
    # Every collected state has sub_phase == subphase, so the action size is fixed.
    action_size = Encodings.calculateActionSize(subphase)
    attack_size = Encodings.calculateAttackSize()
    action_end = world_size + action_size
    input_size = action_end + attack_size

    batch = torch.zeros(batch_size, input_size)
    labels = torch.zeros(batch_size, 1)
    cur_sample = 0

    for sequence in attacks:
        state_action_pairs, last_round = _collect_state_action_pairs(sequence, subphase)
        # Collect a single sample per sequence (as long as one was present).
        if not state_action_pairs:
            continue
        state, action = random.choice(state_action_pairs)
        # The training label is the round of the last logged state, i.e. the round in which the
        # trial ended. NOTE(review): survivors are presumably logged with round 7; confirm
        # against the log writer.
        labels[cur_sample] = last_round
        Encodings.encodeWorldState(world_state=state, encoding=batch[cur_sample, :world_size])
        Encodings.encodeAction(subphase=state.sub_phase, action_list=action,
                               encoding=batch[cur_sample, world_size:action_end])
        Encodings.encodeAttackState(world_state=state, encoding=batch[cur_sample, action_end:])
        cur_sample += 1

        # When a full batch is collected return it immediately, then start fresh buffers.
        if cur_sample == batch_size:
            yield (batch[:cur_sample], labels[:cur_sample])
            batch = torch.zeros(batch_size, input_size)
            labels = torch.zeros(batch_size, 1)
            cur_sample = 0

    # If there are leftover samples that did not fill a batch still return them.
    if 0 < cur_sample:
        yield (batch[:cur_sample], labels[:cur_sample])