def test_die_encodings():
    """Test the dice are encoded correctly in the attack state."""

    agent = RandomAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    template_front_dice = 0
    for color in ['Red', 'Blue', 'Black']:
        if 0 < len(ship_templates["Attacker"][f"Armament Front {color}"]):
            template_front_dice += int(
                ship_templates["Attacker"][f"Armament Front {color}"])

    dice_begin = ArmadaTypes.hull_zones.index('front') * len(
        ArmadaDice.die_colors)
    dice_end = dice_begin + len(ArmadaDice.die_colors)
    front_dice = int(attacker.get_range('dice')[dice_begin:dice_end].sum())
    # The ship encoding should have the same dice as the template
    assert front_dice == template_front_dice

    defender = ship.Ship(name="Defender",
                         template=ship_templates["All Defense Tokens"],
                         upgrades=[],
                         player_number=2)

    encoding, world_state = make_encoding(attacker, defender, "short", agent)

    attack_state_encoding = Encodings.encodeAttackState(world_state)
    die_offset = Encodings.getAttackDiceOffset()
    # The attack state should have a roll with as many dice as the ship has.
    dice_encoding = attack_state_encoding[die_offset:die_offset +
                                          Encodings.dieEncodingSize()]
    assert int(dice_encoding.sum().item()) == front_dice
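

# A hedged sketch of an additional check, not taken from the original suite: it extends
# the front-zone comparison above to every hull zone in ArmadaTypes.hull_zones. The test
# name and the "Armament <Zone> <Color>" key pattern (derived from the front-zone keys
# used above) are assumptions for illustration.
def test_all_zone_die_encodings():
    """Sketch: check that every hull zone's dice match the template's armament."""
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    dice = attacker.get_range('dice')
    for zone_index, zone in enumerate(ArmadaTypes.hull_zones):
        template_dice = 0
        for color in ['Red', 'Blue', 'Black']:
            # Assumed key pattern, mirroring "Armament Front Red" above.
            key = f"Armament {zone.title()} {color}"
            if key in ship_templates["Attacker"] and 0 < len(ship_templates["Attacker"][key]):
                template_dice += int(ship_templates["Attacker"][key])
        dice_begin = zone_index * len(ArmadaDice.die_colors)
        dice_end = dice_begin + len(ArmadaDice.die_colors)
        # Each zone's slice holds one count per die color, as in the front-zone check.
        assert int(dice[dice_begin:dice_end].sum()) == template_dice

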
def update_lifetime_network(lifenet, batch, labels, optimizer, eval_only=False):
    """Do a forward and backward pass through the given lifetime network.

    Args:
        lifenet (torch.nn.Module): Trainable torch model
        batch (torch.tensor)     : Training batch
        labels (torch.tensor)    : Training labels
        optimizer (torch.optim.Optimizer) : Optimizer for lifenet parameters.
        eval_only (bool)         : Only evaluate, don't update parameters.
    Returns:
        float                    : Average absolute error for this batch.
    """
    if eval_only:
        # Switch to eval mode so layers such as dropout and batch norm behave deterministically.
        lifenet.eval()
    # Forward through the prediction network
    prediction = lifenet(batch)
    # Loss is the lifetime prediction error
    # The output cannot be negative, run through a ReLU to clean that up
    #f = torch.nn.ReLU()
    #epsilon = 0.001
    #normed_predictions = f(prediction[0]) + epsilon
    # Compute per-sample absolute errors once; their mean is the value returned to the caller.
    with torch.no_grad():
        errors = (prediction - labels).abs()
        error = errors.mean().item()
        for i in range(errors.size(0)):
            # Debug on crazy errors or nan values.
            if errors[i] > 1000 or torch.isnan(errors[i]):
                # This debugging code is messy, but it is left in because it speeds up
                # diagnosis when a test starts failing after a bug is introduced.
                world_size = Encodings.calculateWorldStateSize()
                world_encoding = batch[i,:world_size]
                phase_name = ArmadaPhases.main_phases[int(world_encoding[1].item())]
                sub_phase_name = ArmadaPhases.sub_phases[phase_name][int(world_encoding[2].item())]
                print(f"Error {i} is {errors[i]}")
                print(f"\tRound {world_encoding[0].item()}")
                print(f"\tSubphase {sub_phase_name}")
                action_size = Encodings.calculateActionSize(sub_phase_name)
                attack_size = Encodings.calculateAttackSize()
                action_encoding = batch[i,world_size:world_size + action_size]
                attack_state_encoding = batch[i,world_size + action_size:]
                if "attack - resolve attack effects" == sub_phase_name:
                    print(f"\tattack effect encoding is {action_encoding}")
                elif "attack - spend defense tokens" == sub_phase_name:
                    print(f"\tspend defense token encoding is {action_encoding}")
                else:
                    print("Cannot print information about {}".format(sub_phase_name))
                defender = Ship(name="Defender", player_number=1,
                                encoding=attack_state_encoding[:Ship.encodeSize()])
                attacker = Ship(name="Attacker", player_number=1,
                                encoding=attack_state_encoding[Ship.encodeSize():2 * Ship.encodeSize()])
                # print(f"\tAttack state encoding is {attack_state_encoding}")
                print("\tAttacker is {}".format(attacker))
                print("\tDefender is {}".format(defender))
                # TODO FIXME Enough dice in a single pool seems to be what ends a ship, but
                # unless the pools are incredibly large that doesn't seem to be happening.
                # Damage does not seem to be accumulating between rounds.
                die_offset = Encodings.getAttackDiceOffset()
                dice_encoding = attack_state_encoding[die_offset:die_offset + Encodings.dieEncodingSize()]
                print("\tDice are {}".format(dice_encoding))
                print(f"\tLabel is {labels[i]}")


    # Normal distribution:
    #normal = torch.distributions.normal.Normal(predict_tensor[0], predict_tensor[1])
    #loss = -normal.log_prob(labels)
    # Poisson distribution (works well)
    #poisson = torch.distributions.poisson.Poisson(normed_predictions)
    #loss = -poisson.log_prob(labels)
    # Plain old MSE error
    #loss_fn = torch.nn.MSELoss()
    #loss = loss_fn(prediction, labels)
    # Absolute error
    loss_fn = torch.nn.L1Loss()
    loss = loss_fn(prediction, labels)

    if not eval_only:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    else:
        # Restore training mode after an evaluation-only pass.
        lifenet.train()
    return error
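

# A minimal usage sketch, not part of the original module: it shows one way
# update_lifetime_network could be driven from a training loop. The model architecture,
# input_size, the random placeholder data, and the hyperparameters are all assumptions
# for illustration; real batches would come from encoded world states.
def example_lifetime_training_loop(input_size=128, batch_size=32, steps=100):
    """Sketch: fit a small fully connected lifetime predictor on placeholder data."""
    lifenet = torch.nn.Sequential(
        torch.nn.Linear(input_size, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 1))
    optimizer = torch.optim.Adam(lifenet.parameters(), lr=1e-3)
    for step in range(steps):
        # Placeholder inputs and non-negative lifetime labels.
        batch = torch.randn(batch_size, input_size)
        labels = torch.randn(batch_size, 1).abs()
        error = update_lifetime_network(lifenet, batch, labels, optimizer)
        if 0 == step % 10:
            print(f"Step {step}: mean absolute error {error}")
    # An evaluation-only pass reports the error without touching the parameters.
    eval_error = update_lifetime_network(lifenet, batch, labels, optimizer, eval_only=True)
    print(f"Eval-only error: {eval_error}")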