def test_die_encodings():
    """Test the dice are encoded correctly in the attack state."""
    agent = RandomAgent()
    attacker = ship.Ship(name="Attacker", template=ship_templates["Attacker"],
                         upgrades=[], player_number=1)

    # Tally the front-arc armament dice listed directly in the ship template.
    expected_front = 0
    for color in ('Red', 'Blue', 'Black'):
        armament = ship_templates["Attacker"][f"Armament Front {color}"]
        if len(armament) > 0:
            expected_front += int(armament)

    # Locate the front hull zone's slice within the ship's dice encoding:
    # each hull zone occupies one slot per die color.
    colors_per_zone = len(ArmadaDice.die_colors)
    begin = ArmadaTypes.hull_zones.index('front') * colors_per_zone
    encoded_front = int(attacker.get_range('dice')[begin:begin + colors_per_zone].sum())

    # The ship encoding should have the same dice as the template
    assert encoded_front == expected_front

    defender = ship.Ship(name="Defender", template=ship_templates["All Defense Tokens"],
                         upgrades=[], player_number=2)
    encoding, world_state = make_encoding(attacker, defender, "short", agent)
    attack_state = Encodings.encodeAttackState(world_state)

    # The attack state should have a roll with as many dice as the ship has.
    offset = Encodings.getAttackDiceOffset()
    roll = attack_state[offset:offset + Encodings.dieEncodingSize()]
    assert int(roll.sum().item()) == encoded_front
def _print_outlier_debug(batch, labels, errors, i):
    """Dump a human-readable breakdown of training sample *i* whose error is
    huge or NaN.

    This is messy debugging code, but sometimes a test may fail after bug
    introductions so it is helpful to leave this in to speed up debugging.
    """
    world_size = Encodings.calculateWorldStateSize()
    world_encoding = batch[i, :world_size]
    phase_name = ArmadaPhases.main_phases[int(world_encoding[1].item())]
    sub_phase_name = ArmadaPhases.sub_phases[phase_name][int(world_encoding[2].item())]
    print(f"Error {i} is {errors[i]}")
    print(f"\tRound {world_encoding[0].item()}")
    print(f"\tSubphase {sub_phase_name}")
    action_size = Encodings.calculateActionSize(sub_phase_name)
    action_encoding = batch[i, world_size:world_size + action_size]
    attack_state_encoding = batch[i, world_size + action_size:]
    if "attack - resolve attack effects" == sub_phase_name:
        print(f"\tattack effect encoding is {action_encoding}")
    elif "attack - spend defense tokens" == sub_phase_name:
        print(f"\tspend defense token encoding is {action_encoding}")
    else:
        print("Cannot print information about {}".format(sub_phase_name))
    # The attack state encodes the defender's ship first, then the attacker's.
    defender = Ship(name="Defender", player_number=1,
                    encoding=attack_state_encoding[:Ship.encodeSize()])
    attacker = Ship(name="Attacker", player_number=1,
                    encoding=attack_state_encoding[Ship.encodeSize():2 * Ship.encodeSize()])
    print("\tAttacker is {}".format(attacker))
    print("\tDefender is {}".format(defender))
    # TODO FIXME Enough dice in a pool seems to end a ship, but unless the pools are
    # incredibly large this doesn't seem to be happening. Damage does not seem to be
    # accumulating between rounds.
    die_offset = Encodings.getAttackDiceOffset()
    dice_encoding = attack_state_encoding[die_offset:die_offset + Encodings.dieEncodingSize()]
    print("\tDice are {}".format(dice_encoding))
    print(f"\tLabel is {labels[i]}")


def update_lifetime_network(lifenet, batch, labels, optimizer, eval_only=False):
    """Do a forward and backward pass through the given lifetime network.

    Args:
        lifenet (torch.nn.Module): Trainable torch model
        batch (torch.Tensor): Training batch
        labels (torch.Tensor): Training labels
        optimizer (torch.optim.Optimizer): Optimizer for lifenet parameters.
        eval_only (bool): Only evaluate, don't update parameters.
    Returns:
        float: Average absolute error for this batch
    """
    if eval_only:
        # Switch to inference mode for the forward pass; train mode is
        # restored before returning so callers can keep training afterwards.
        lifenet.eval()
    # Forward through the prediction network
    prediction = lifenet(batch)

    # Compute the per-sample absolute errors once and reuse them for both the
    # returned mean and the outlier scan (the original computed this twice).
    with torch.no_grad():
        errors = (prediction - labels).abs()
        error = errors.mean().item()
        for i in range(errors.size(0)):
            # Debug on crazy errors or NaN values (NaN != NaN, hence isnan).
            if errors[i] > 1000 or torch.isnan(errors[i]).any():
                _print_outlier_debug(batch, labels, errors, i)

    # Plain absolute (L1) error. Normal/Poisson log-prob and MSE losses were
    # also tried during development; L1 worked best.
    loss_fn = torch.nn.L1Loss()
    loss = loss_fn(prediction, labels)

    if not eval_only:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    else:
        # Restore training mode after an eval-only pass.
        lifenet.train()
    return error