示例#1
0
def test_die_encodings():
    """Check that the attacker's front dice appear in the ship and attack state encodings.

    The front armament is read from the "Attacker" template, compared with the 'dice' range of
    the ship encoding, and then compared with the dice section of the encoded attack state.
    """
    random_agent = RandomAgent()
    attacker_template = ship_templates["Attacker"]
    attacker = ship.Ship(name="Attacker",
                         template=attacker_template,
                         upgrades=[],
                         player_number=1)

    # Add up the front armament of every die color, ignoring empty template fields.
    front_armament = [
        attacker_template[f"Armament Front {color}"]
        for color in ('Red', 'Blue', 'Black')
    ]
    expected_dice = sum(int(count) for count in front_armament if len(count) > 0)

    # The ship stores one entry per die color for each hull zone.
    colors_per_zone = len(ArmadaDice.die_colors)
    zone_begin = ArmadaTypes.hull_zones.index('front') * colors_per_zone
    zone_end = zone_begin + colors_per_zone
    encoded_dice = int(attacker.get_range('dice')[zone_begin:zone_end].sum())
    # The ship encoding should have the same dice as the template
    assert encoded_dice == expected_dice

    defender = ship.Ship(name="Defender",
                         template=ship_templates["All Defense Tokens"],
                         upgrades=[],
                         player_number=2)

    _, world_state = make_encoding(attacker, defender, "short", random_agent)

    # The attack state should have a roll with as many dice as the ship has.
    attack_encoding = Encodings.encodeAttackState(world_state)
    roll_begin = Encodings.getAttackDiceOffset()
    roll_end = roll_begin + Encodings.dieEncodingSize()
    rolled_dice = attack_encoding[roll_begin:roll_end]
    assert int(rolled_dice.sum().item()) == encoded_dice
示例#2
0
    def __init__(self, with_novelty=True):
        """Create the defense token network and, optionally, the novelty networks.

        Args:
            with_novelty (bool): When True the "def_tokens_novelty" and "def_tokens_static"
                                 models used for random distillation are created as well.
        """
        nn.Module.__init__(self)
        self.with_novelty = with_novelty

        # The input is an attack encoding. Two extra outputs are added to the spend defense
        # tokens action for the mean and variance estimation.
        input_size = Encodings.calculateAttackSize()
        output_size = Encodings.calculateSpendDefenseTokensSize() + 2

        self.models = nn.ModuleDict()
        self.models["def_tokens"] = self.init_fc_params(input_size, output_size)
        self.optimizers = {}
        self.optimizers["def_tokens"] = torch.optim.Adam(
            self.models["def_tokens"].parameters())

        if self.with_novelty:
            # Two more models are created for random distillation. The static model is never
            # given an optimizer here and the novelty model will learn to predict the outputs
            # of the static one. The difference between the two outputs will be used to
            # estimate the novelty of the current state. If the state is new then the novelty
            # model will not be able to make a good prediction of the static model's outputs.
            # In other words, the static model projects the inputs into a new latent space and
            # the ability of the novelty model to predict that projection should be correlated
            # to how similar this state is to ones we have previously visited.

            # The novelty entry holds a newly initialized network followed by the def_tokens
            # network itself, so the novelty optimizer also covers the def_tokens parameters.
            predictor = self.init_fc_params(input_size, output_size)
            self.models["def_tokens_novelty"] = nn.ModuleList(
                [predictor, self.models["def_tokens"]])
            self.models["def_tokens_static"] = self.init_fc_params(
                input_size, output_size)
            self.optimizers["def_tokens_novelty"] = torch.optim.Adam(
                self.models["def_tokens_novelty"].parameters())

        self.sm = nn.Softmax()
示例#3
0
def test_accuracy_encodings():
    """Check the accuracy flags of the attack encoding as brace tokens are targeted.

    The flags are expected in two consecutive sections starting at the attack token offset:
    one entry per defense token type for green tokens, followed by the same for red tokens.
    """
    learner = LearningAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    brace_ship = ship.Ship(name="Double Brace",
                           template=ship_templates["Triple Brace"],
                           upgrades=[],
                           player_number=2)

    # Spend one green brace token before the attack is set up.
    brace_ship.spend_token('brace', ArmadaTypes.green)

    initial_encoding, world_state = make_encoding(attacker, brace_ship,
                                                  "short", learner)

    # Nothing has been targeted by an accuracy yet, so the token section must be empty.
    section_begin = Encodings.getAttackTokenOffset()
    section_end = section_begin + ArmadaTypes.max_defense_tokens
    assert 0.0 == initial_encoding[section_begin:section_end].sum()

    # Spend a second green brace token, then target a red brace token with an accuracy.
    brace_ship.spend_token('brace', ArmadaTypes.green)
    token_types = len(ArmadaTypes.defense_tokens)
    brace = ArmadaTypes.defense_tokens.index('brace')
    green_begin = Encodings.getAttackTokenOffset()
    green_end = green_begin + token_types
    red_begin = green_end
    red_end = red_begin + token_types
    world_state.attack.accuracy_defender_token(brace, ArmadaTypes.red)
    encoding = Encodings.encodeAttackState(world_state)

    # Only the red brace entry may carry an accuracy flag at this point.
    assert encoding[red_begin + brace].item() == 1.
    assert encoding[red_begin:red_end].sum().item() == 1.
    assert encoding[green_begin:green_end].sum().item() == 0.

    # Target green brace tokens with two more accuracies.
    for _ in range(2):
        world_state.attack.accuracy_defender_token(brace, ArmadaTypes.green)
    encoding = Encodings.encodeAttackState(world_state)

    # Two green braces and one red brace should now have the accuracy flag.
    assert encoding[red_begin + brace].item() == 1.
    assert encoding[red_begin:red_end].sum().item() == 1.
    assert encoding[green_begin + brace].item() == 2.
    assert encoding[green_begin:green_end].sum().item() == 2.
示例#4
0
def test_spent_encodings():
    """Check the spent token section of the attack encoding for every defense token type."""
    learner = LearningAgent()

    def fresh_attack():
        """Build new ships and return make_encoding's result for a short range attack."""
        attacker = ship.Ship(name="Attacker",
                             template=ship_templates["Attacker"],
                             upgrades=[],
                             player_number=1)
        defender = ship.Ship(name="Defender",
                             template=ship_templates["All Defense Tokens"],
                             upgrades=[],
                             player_number=2)
        return make_encoding(attacker, defender, "short", learner)

    encoding, world_state = fresh_attack()

    # The defender and attacker come first, then the accuracied tokens, then the spent tokens
    token_types = len(ArmadaTypes.defense_tokens)
    spent_begin = 2 * ship.Ship.encodeSize() + 2 * token_types
    spent_end = spent_begin + token_types

    # Without any spending the section must be all zeros.
    assert torch.sum(encoding[spent_begin:spent_end]) == 0.

    # Spending one token of every type should mark every entry.
    for token_name in ArmadaTypes.defense_tokens:
        world_state.attack.defender_spend_token(token_name, 'green')

    encoding = Encodings.encodeAttackState(world_state)
    assert torch.sum(encoding[spent_begin:spent_end]).item() == token_types

    # Spending a single token in a new attack should mark only that token's entry.
    for token_idx, token_name in enumerate(ArmadaTypes.defense_tokens):
        encoding, world_state = fresh_attack()
        world_state.attack.defender_spend_token(token_name, 'green')
        encoding = Encodings.encodeAttackState(world_state)
        spent_section = encoding[spent_begin:spent_end]
        assert torch.sum(spent_section).item() == 1.0
        assert spent_section[token_idx].item() == 1.0
示例#5
0
def make_encoding(ship_a, ship_b, attack_range, agent):
    """Set up a front-to-front attack from ship_a on ship_b and encode the attack state.

    Both ships are added to a new WorldState, ship_a rolls the dice of its front hull zone at
    the given range, and the resulting AttackState is stored in the world state before it is
    encoded.

    Args:
      ship_a (Ship): Attacking ship.
      ship_b (Ship): Defending ship.
      attack_range (str): Attack range, e.g. "short".
      agent: Not used by this function; kept so that existing callers do not need to change.

    Returns:
      tuple: (encoding, world_state) where encoding is the result of
             Encodings.encodeAttackState(world_state).
    """
    world_state = WorldState()
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)

    # Roll the attack pool for the attacker's front hull zone.
    pool_colors, pool_faces = ship_a.roll("front", attack_range)
    attack = AttackState(attack_range=attack_range,
                         attacker=ship_a,
                         attacking_hull="front",
                         defender=ship_b,
                         defending_hull="front",
                         pool_colors=pool_colors,
                         pool_faces=pool_faces)
    world_state.updateAttack(attack)

    encoding = Encodings.encodeAttackState(world_state)

    return encoding, world_state
示例#6
0
    def __init__(self):
        """Create one fully connected model and one Adam optimizer per supported phase.

        Every model takes the world state and the attack state as input and outputs the action
        for its phase.
        """
        # All models share the same input size: world state plus attack state.
        input_size = (Encodings.calculateWorldStateSize() +
                      Encodings.calculateAttackSize())

        nn.Module.__init__(self)
        self.models = nn.ModuleDict()
        self.optimizers = {}

        # TODO FIXME Loop through all of the possible phases instead of just these two
        supported_phases = ("attack - resolve attack effects",
                            "attack - spend defense tokens")
        for phase in supported_phases:
            output_size = Encodings.calculateActionSize(phase)
            phase_model = self.init_fc_params(input_size, output_size)
            self.models[phase] = phase_model
            self.optimizers[phase] = torch.optim.Adam(phase_model.parameters())

        self.sm = nn.Softmax()
示例#7
0
def test_roll_encodings():
    """Test that the encoding is correct for dice pools and faces.

    A new roll is placed into the same attack state 100 times. After each roll the attack
    state is encoded again and every (color, face) pair of the pool must be found in the
    part of the encoding that starts at the attack dice offset.
    """
    agent = LearningAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    no_token = ship.Ship(name="No Defense Tokens",
                         template=ship_templates["No Defense Tokens"],
                         upgrades=[],
                         player_number=2)

    # The offset does not depend on the roll, look it up once.
    dice_begin = Encodings.getAttackDiceOffset()

    # Do 100 trials to ensure everything is working as expected
    _, world_state = make_encoding(attacker, no_token, "short", agent)
    for _ in range(100):
        # Make a random roll and encode the attack state
        pool_colors, pool_faces = attacker.roll("front", "short")
        attack = world_state.attack
        attack.pool_faces = pool_faces
        attack.pool_colors = pool_colors
        world_state.updateAttack(attack)
        enc_attack = Encodings.encodeAttackState(world_state)

        # Try to find a match for each color,face pair in the encoding and remove it again so
        # that duplicates in the pool need their own count.
        enc_dice = enc_attack[dice_begin:]
        for face, color in zip(attack.pool_faces, attack.pool_colors):
            die_index = Encodings.dieOffset(color=color, face=face)
            assert 0. < enc_dice[die_index].item()
            enc_dice[die_index] -= 1
        # All dice from the pool should have been matched and there should be no more encoded
        assert enc_dice.sum().item() <= 0.
示例#8
0
def test_range_encodings():
    """Check that every attack range sets exactly one range flag, at its own offset."""
    learner = LearningAgent()
    attacker = ship.Ship(name="Attacker",
                         template=ship_templates["Attacker"],
                         upgrades=[],
                         player_number=1)
    target = ship.Ship(name="No Defense Tokens",
                       template=ship_templates["No Defense Tokens"],
                       upgrades=[],
                       player_number=2)

    first_flag = Encodings.getAttackRangeOffset()
    for range_idx, range_name in enumerate(ArmadaTypes.ranges):
        encoding, _ = make_encoding(attacker, target, range_name, learner)
        range_flags = encoding[first_flag:first_flag + len(ArmadaTypes.ranges)]
        # Exactly one flag is set and it is the one for the requested range.
        assert torch.sum(range_flags) == 1
        assert 1.0 == encoding[first_flag + range_idx].item()
    def __init__(self, model=None):
        """Initialize the learning agent with handlers for two attack sub-phases.

        Args:
            model (torch.nn.Module or None): If None this agent will pass for all supported states
        """
        # Map the supported state names to the methods that decide the action for them.
        handler = {
            "ship phase - attack - resolve attack effects":
            self.resolveAttackEffects,
            "ship phase - attack - spend defense tokens":
            self.spendDefenseTokens
        }
        super(LearningAgent, self).__init__(handler)
        # Use the GPU whenever one is available.
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = model
        if self.model is not None:
            # The model is put into eval mode on the agent's device.
            self.model = self.model.to(self.device)
            self.model.eval()

        self.attack_enc_size = Encodings.calculateAttackSize()
        # Nothing is recorded until `remembering` is switched on.
        self.memory = []
        self.remembering = False
示例#10
0
def _print_lifetime_outlier(i, batch, labels, errors):
    """Print debugging details about sample i of a batch with a huge or NaN error."""
    # This is messy debugging code, but sometimes a test may fail after bug
    # introductions so it is helpful to leave this in to speed up debugging.
    world_size = Encodings.calculateWorldStateSize()
    world_encoding = batch[i, :world_size]
    phase_name = ArmadaPhases.main_phases[int(world_encoding[1].item())]
    sub_phase_name = ArmadaPhases.sub_phases[phase_name][int(world_encoding[2].item())]
    print(f"Error {i} is {errors[i]}")
    print(f"\tRound {world_encoding[0].item()}")
    print(f"\tSubphase {sub_phase_name}")
    # A sample is laid out as: world state, action, attack state.
    action_size = Encodings.calculateActionSize(sub_phase_name)
    action_encoding = batch[i, world_size:world_size + action_size]
    attack_state_encoding = batch[i, world_size + action_size:]
    if "attack - resolve attack effects" == sub_phase_name:
        print(f"\tattack effect encoding is {action_encoding}")
    elif "attack - spend defense tokens" == sub_phase_name:
        print(f"\tspend defense token encoding is {action_encoding}")
    else:
        print("Cannot print information about {}".format(sub_phase_name))
    defender = Ship(name="Defender", player_number=1,
                    encoding=attack_state_encoding[:Ship.encodeSize()])
    attacker = Ship(name="Attacker", player_number=1,
                    encoding=attack_state_encoding[Ship.encodeSize():2 * Ship.encodeSize()])
    print("\tAttacker is {}".format(attacker))
    print("\tDefender is {}".format(defender))
    # TODO FIXME Enough dice in a pool seems to end a ship, but unless the pools are
    # incredibly large this doesn't seem to be happening. Damage does not seem to be
    # accumlating between rounds.
    die_offset = Encodings.getAttackDiceOffset()
    dice_encoding = attack_state_encoding[die_offset:die_offset + Encodings.dieEncodingSize()]
    print("\tDice are {}".format(dice_encoding))
    print(f"\tLabel is {labels[i]}")


def update_lifetime_network(lifenet, batch, labels, optimizer, eval_only=False):
    """Do a forward and (unless eval_only) backward pass through the given lifetime network.

    The training loss is the mean absolute (L1) error between predictions and labels. Samples
    whose absolute error is above 1000 or NaN are dumped to stdout for debugging.

    Args:
        lifenet (torch.nn.Module): Trainable torch model
        batch (torch.tensor)     : Training batch
        labels (torch.tensor)    : Training labels
        optimizer (torch.optim.Optimizer) : Optimizer for lifenet parameters. Not used (may be
                                   None) when eval_only is True.
        eval_only (bool)         : Only evaluate, don't update parameters. The network is put
                                   into eval mode for the forward pass and afterwards returned
                                   to the mode it was in when this function was called.
    Returns:
        float                    : Average absolute error for this batch
    """
    # Remember the mode so that an evaluation pass does not silently switch a network that
    # the caller keeps in eval mode back into train mode.
    was_training = lifenet.training
    if eval_only:
        lifenet.eval()
    try:
        # Forward through the prediction network. No graph is needed for evaluation.
        if eval_only:
            with torch.no_grad():
                prediction = lifenet.forward(batch)
        else:
            prediction = lifenet.forward(batch)

        with torch.no_grad():
            errors = (prediction - labels).abs()
            error = errors.mean().item()
            # Debug on crazy errors or nan values. The whole batch is checked at once so that
            # the per sample loop only runs when there is something to report.
            suspicious = (errors > 1000) | torch.isnan(errors)
            if suspicious.any():
                for i in range(errors.size(0)):
                    if suspicious[i].any():
                        _print_lifetime_outlier(i, batch, labels, errors)

        if not eval_only:
            # Absolute error. Alternatives tried here earlier were the negative log
            # likelihood of a normal or Poisson distribution (the latter was noted to work
            # well) and plain MSE.
            loss_fn = torch.nn.L1Loss()
            loss = loss_fn(prediction, labels)
            optimizer.zero_grad()
            loss.sum().backward()
            optimizer.step()
    finally:
        if eval_only:
            lifenet.train(was_training)
    return error
示例#11
0
def test_policy_learning(spend_defense_tokens_model, resolve_attack_effects_model):
    """Train a model to produce a better than random choice policy for defense tokens.

    The given lifetime prediction models are used to determine the quality of the policy
    network's output. There will not be an update step of the prediction models as in
    reinforcement learning, this is just testing the mechanism.

    Args:
        spend_defense_tokens_model (tuple)  : (model, training errors, eval errors) for the
                                              'spend defense tokens' subphase.
        resolve_attack_effects_model (tuple): (model, training errors, eval errors) for the
                                              'resolve attack effects' subphase.
    """
    # Run on the GPU when there is one, otherwise on the CPU, instead of requiring CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def_tokens_model, _, _ = spend_defense_tokens_model
    res_attack_model, _, _ = resolve_attack_effects_model
    # The prediction models are only used for scoring, keep them in eval mode.
    prediction_models = {
        "attack - spend defense tokens": def_tokens_model.eval(),
        "attack - resolve attack effects": res_attack_model.eval()
    }
    # Do the training. Use the prediction model lifetime to create the loss target. The loss
    # will be the difference between the max possible round and the predicted round.
    # For the defense token model the higher round is better, for the attack effect model a
    # lower round is better.
    loss_fn = {
        "attack - spend defense tokens": lambda predictions: 7.0 - predictions,
        "attack - resolve attack effects": lambda predictions: predictions - 1.
    }
    subphases = list(loss_fn)

    # Generate a new learning model
    learning_agent = LearningAgent(SeparatePhaseModel())
    # TODO FIXME The learning agent doesn't really have a training mode
    learning_agent.model.train()
    random_agent = RandomAgent()

    training_ships = ["All Defense Tokens", "All Defense Tokens",
                      "Imperial II-class Star Destroyer", "MC80 Command Cruiser",
                      "Assault Frigate Mark II A", "No Shield Ship", "One Shield Ship",
                      "Mega Die Ship"]
    attackers = [Ship(name=name, template=ship_templates[name],
                      upgrades=[], player_number=1, device='cpu')
                 for name in training_ships]
    defenders = [Ship(name=name, template=ship_templates[name],
                      upgrades=[], player_number=2, device='cpu')
                 for name in training_ships]

    batch_size = 32
    # Remember the training loss values to test for improvement
    losses = {subphase: [] for subphase in subphases}
    # This gets samples to use for training. We will use a random agent to generate the states.
    # In reinforcement learning the agent would alternate between random actions and actions
    # from the learning agent to balance exploration of the state space with exploitation of
    # learned behavior.
    samples = get_n_examples(
        n_examples=1000, ship_a=attackers, ship_b=defenders, agent=random_agent)

    world_size = Encodings.calculateWorldStateSize()
    # Get a batch for each subphase
    for subphase in subphases:
        # The learning_agent will take in the world state and attack state and will produce a new
        # action encoding. Then the prediction network will take in a new tuple with this action and
        # predict a final round. The loss defined above is computed from that prediction.
        action_size = Encodings.calculateActionSize(subphase)
        # TODO FIXME Just make a forward function that takes in a phase name
        policy = learning_agent.model.models[subphase]
        optimizer = learning_agent.model.get_optimizer(subphase)
        for batch, _ in collect_attack_batches(batch_size=batch_size, attacks=samples,
                                               subphase=subphase):
            batch = batch.to(device)
            # A sample is laid out as: world state, action, attack state.
            world_part = batch[:, :world_size]
            attack_part = batch[:, world_size + action_size:]
            # The policy sees the state without the recorded action...
            new_actions = policy(torch.cat((world_part, attack_part), dim=1))
            # ...and its action replaces the recorded one for the lifetime prediction.
            new_action_state = torch.cat((world_part, new_actions, attack_part), dim=1)
            action_result = prediction_models[subphase](new_action_state)
            loss = loss_fn[subphase](action_result)
            optimizer.zero_grad()
            loss.sum().backward()
            optimizer.step()
            losses[subphase].append(loss.detach().mean().item())
            # In reinforcement learning there would also be a phase where the prediction models are
            # updated.
    for subphase in subphases:
        assert losses[subphase], f"no training batches were produced for {subphase}"
        print(f"losses for {subphase} start with {losses[subphase][0:5]} and end with {losses[subphase][-5:]}")
        assert losses[subphase][-1] < losses[subphase][0]
    # TODO FIXME HERE See if the policy networks produce better results than random actions
    learning_agent.model.models.eval()
示例#12
0
def spend_defense_tokens_model(create_spend_defense_tokens_dataset):
    """Train a basic lifetime prediction model for the 'spend defense tokens' subphase.

    Creates a simple network that predicts a lifetime from a world state, action, and attack
    state encoding and trains it for 10 epochs. The network is evaluated once before training
    and again after every epoch.

    Args:
        create_spend_defense_tokens_dataset (tuple): (attacks, eval_attacks), two iterables of
            batches where batch[0] is the data and batch[1] are the labels.

    Returns:
        (nn.module, [float], [float]): The model, one training error per training batch, and
                                       one eval error per evaluated batch.
    """
    attacks, eval_attacks = create_spend_defense_tokens_dataset

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Lifetime network B predicts lifetimes given a state and action pair in the 'spend defense
    # tokens' subphase.
    phase_name = "attack - spend defense tokens"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    network = torch.nn.Sequential(
        torch.nn.BatchNorm1d(input_size),
        torch.nn.Linear(input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 4 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(4 * input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 1))
    network.to(target_device)
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0005)

    # Keep track of the errors for the purpose of this test
    errors = []
    eval_errors = []

    def evaluate():
        """Append the error of every evaluation batch to eval_errors without training."""
        for eval_batch in eval_attacks:
            eval_data = eval_batch[0].to(target_device)
            eval_labels = eval_batch[1].to(target_device)
            eval_errors.append(update_lifetime_network(network, eval_data,
                                                       eval_labels, None, True))

    # Evaluate before training and every epoch
    evaluate()
    # Train with all of the data for 10 epochs
    for epoch in range(10):
        print("Training resolve_spend_defense_tokens_model epoch {}".format(epoch))
        train_batches = 0
        for batch in attacks:
            train_data = batch[0].to(target_device)
            train_labels = batch[1].to(target_device)
            errors.append(update_lifetime_network(network, train_data,
                                                  train_labels, optimizer))
            train_batches += 1
        print("Finished epoch with {} batches.".format(train_batches))
        evaluate()

    return network, errors, eval_errors
示例#13
0
def resolve_attack_effects_model(create_attack_effects_dataset):
    """Train a basic lifetime prediction model for the 'resolve attack effects' subphase.

    Creates a simple network that predicts a lifetime from a world state, action, and attack
    state encoding and trains it for 10 epochs. The network is evaluated once before training
    and again after every epoch.

    Args:
        create_attack_effects_dataset (tuple): (attacks, eval_attacks), two iterables of
            batches where batch[0] is the data and batch[1] are the labels.

    Returns:
        (nn.module, [float], [float]): The model, one training error per training batch, and
                                       one eval error per evaluated batch.
    """
    attacks, eval_attacks = create_attack_effects_dataset

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Lifetime network A predicts lifetimes given a state and action pair in the 'resolve attack
    # effects' subphase.
    phase_name = "attack - resolve attack effects"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    # The network size was made large enough that training plateaued to a stable value
    # If the network is too large it has a tendency to start fitting to specific cases.
    # Batchnorm doesn't seem to help out much with this network and task.
    # NOTE(review): that remark presumably refers to the disabled hidden layer variants below,
    # the input is still normalized by the first layer - confirm.
    network = torch.nn.Sequential(
        torch.nn.BatchNorm1d(input_size),
        torch.nn.Linear(input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 4 * input_size),
        torch.nn.ELU(),
        #torch.nn.Dropout(),
        #torch.nn.BatchNorm1d(4 * input_size),
        torch.nn.Linear(4 * input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 1))
    network.to(target_device)
    # Higher learning rates lead to a lot of instability in the training.
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0005)

    # Keep track of the errors for the purpose of this test
    errors = []
    eval_errors = []

    def evaluate():
        """Append the error of every evaluation batch to eval_errors without training."""
        for eval_batch in eval_attacks:
            eval_data = eval_batch[0].to(target_device)
            eval_labels = eval_batch[1].to(target_device)
            eval_errors.append(update_lifetime_network(network, eval_data,
                                                       eval_labels, None, True))

    # Evaluate before training and every epoch
    evaluate()
    # Train with all of the data for 10 epochs
    for epoch in range(10):
        print("Training resolve_attack_effects_model epoch {}".format(epoch))
        train_batches = 0
        for batch in attacks:
            train_data = batch[0].to(target_device)
            train_labels = batch[1].to(target_device)
            errors.append(update_lifetime_network(network, train_data,
                                                  train_labels, optimizer))
            train_batches += 1
        print("Finished epoch with {} batches.".format(train_batches))
        evaluate()

    return network, errors, eval_errors
    def spendDefenseTokens(self, world_state):
        """
        Args:
            world_state (table)   : Contains the list of ships and dice pool.
            current_step (string) : This function only operates on the "spend defense tokens" step.
        Returns:
            (str, varies) : A tuple of the token name to spend and the token targets of die to target
                            with the token.  If no token will be spent then both values of the tuple are
                            None and if the defense token has no target then the second value is None.
                            For evade tokens the second value is the index of the die to target.
                            For redirect tokens the second value is a tuple of (str, int) for the
                            target hull zone and the amount of damage to direct to that hull zone.
        """
        # TODO FIXME The return type is totally messed up, a list of actions should be allowable
        # from here
        # We only handle one sub-phase in this function
        assert world_state.sub_phase == "attack - spend defense tokens"

        if None == self.model:
            # Return no action
            return []
        # Encode the state, forward through the network, decode the result, and return the result.
        as_enc = Encodings.encodeAttackState(world_state)
        as_enc = as_enc.to(self.device)
        if not self.model.with_novelty:
            # Forward through the policy net randomly, otherwise return random actions
            if self.randprob >= random.random():
                # Take a random action
                action = self.random_agent("def_tokens", as_enc)[0]
            else:
                action = self.model.forward("def_tokens", as_enc)[0]
            # Remember this state action pair if in memory mode
            if self.remembering:
                self.memory.append((world_state.attack, as_enc, action))
        else:
            action, novelty = self.model.forward("def_tokens", as_enc)
            # Remove the batch dimension
            action = action[0]
            novelty = novelty[0]
            # Remember this state action pair if in memory mode
            if self.remembering:
                self.memory.append(
                    (world_state.attack, as_enc, action, novelty))
        # Don't return the lifetime prediction (used in train_defense_tokens.py)
        #with torch.no_grad():
        #    action = torch.round(action[:Encodings.calculateSpendDefenseTokensSize()])
        # Determine which tokens should be spent. Spend any above a threshold.
        with torch.no_grad():
            spend_green = action[:len(ArmadaTypes.defense_tokens)] > math.log(
                0.5)
            spend_red = action[len(ArmadaTypes.defense_tokens):2 *
                               len(ArmadaTypes.defense_tokens)] > math.log(0.5)
            spent_tokens = spend_green + spend_red
            # Return now if no token will be spent
            if (0 == spent_tokens).all():
                return []

        # Handle the tokens

        # First check for validity. If the selected token isn't valid then use no token.
        # TODO Perhaps this should be penalized in some way
        green_idx, green_len = Ship.get_index('defense_tokens_green')
        red_idx, red_len = Ship.get_index('defense_tokens_red')
        defender_green_tokens = world_state.attack.defender.encoding[
            green_idx:green_idx + green_len]
        defender_red_tokens = world_state.attack.defender.encoding[
            red_idx:red_idx + red_len]
        if (spend_green > defender_green_tokens).any() or (
                spend_red > defender_red_tokens).any():
            return []

        # Verify that these tokens have not been the target of an accuracy and that they can be
        # spent
        for idx in len(ArmadaTypes.defense_tokens):
            if 0 < world_state.attack.accuracy_tokens[
                    idx] and 0 < spent_tokens[idx]:
                return []
        # TODO FIXME Those last two checks (for token availability and non-accuracy status) should
        # be enforced via a fixed input to the network that suppresses token outputs if they are not
        # available. This would make learning simpler.

        actions = []
        # Handle the token, decoding the returned action based upon the token type.
        evade_index = ArmadaTypes.defense_tokens.index("evade")
        if 0 < defender_green_tokens[evade_index].item(
        ) + defender_red_tokens[evade_index].item():
            begin = Encodings.getSpendDefenseTokensEvadeOffset()
            end = begin + Encodings.max_die_slots
            # Get the index of the maximum die response
            _, die_idx = action[begin:end].max(0)
            # Check for an invalid response from the agent
            # TODO Perhaps this should be penalized in some way
            # TODO FIXME This should also be supressed through die availability.
            # TODO FIXME Also handle extreme range with 2 die targets
            if len(die_slots) <= die_idx.item():
                pass
            else:
                color = ArmadaTypes.green if 0 < defender_green_tokens[
                    evade_index].item() else ArmadaTypes.red
                src_die_slot = die_slots[die_idx.item()]
                actions.append(("evade", (src_die_slot)))

        # TODO This only supports redirecting to a single hull zone currently
        redir_index = ArmadaTypes.defense_tokens.index("redirect")
        if 0 < spent_tokens[ArmadaTypes.defense_tokens.index("redirect")]:
            begin = Encodings.getSpendDefenseTokensRedirectOffset()
            end = begin + len(ArmadaTypes.hull_zones)
            # The encoding has a value for each hull zone. We should check if an upgrade allows the
            # defender to redirect to nonadjacent or multiple hull zones, but for now we will just
            # handle the base case. TODO
            adj_hulls = world_state.attack.defender.adjacent_zones(
                world_state.attack.defending_hull)

            # Redirect to whichever hull has the greater value TODO
            redir_hull = None
            redir_amount = 0

            for hull in adj_hulls:
                hull_redir_amount = round(
                    action[begin + ArmadaTypes.hull_zones.index(hull)].item())
                if hull_redir_amount > redir_amount:
                    redir_hull = hull
                    redir_amount = hull_redir_amount

            # Make sure there is actual redirection
            if None != redir_hull:
                color = ArmadaTypes.green if 0 < defender_green_tokens[
                    redir_index].item() else ArmadaTypes.red
                actions.append(("redirect", (redir_hull, redir_amount)))

        # Other defense tokens with no targets
        for tindx, token_type in enumerate(ArmadaTypes.defense_tokens):
            if token_type not in ["evade", "redirect"]:
                if 0 < spent_tokens[tindx]:
                    color = ArmadaTypes.green if 0 < defender_green_tokens[
                        tindx].item() else ArmadaTypes.red
                    actions.append((token_type, (color, None)))

        return actions
示例#15
0
def test_resolve_attack_effects_model(resolve_attack_effects_model):
    """Check lifetime predictions of the resolve attack effects network.

    First confirms that the evaluation error dropped over training. Then four
    hand-built attack scenarios are pushed through the network and the
    predicted lifetimes are checked to be ordered from most to least dangerous:

    1. An attack upon a ship with only 1 hull remaining.
    2. The same dice pool against a ship with full hull.
    3. A dice pool made of blanks only.
    4. A dice pool made of blanks only, fired at long range.
    """
    model, _train_errors, eval_errors = resolve_attack_effects_model
    model.eval()
    sub_phase = "attack - resolve attack effects"
    # Network input layout: world state, then action, then attack state.
    input_size = (Encodings.calculateWorldStateSize() +
                  Encodings.calculateActionSize(sub_phase) +
                  Encodings.calculateAttackSize())

    # Training is only considered successful if the evaluation error went down.
    print("Eval errors for A were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    scenario_count = 4
    # Uninitialized storage is fine here: every row is overwritten below.
    batch = torch.Tensor(scenario_count, input_size).to(device)

    # Two identical ships; ship A attacks ship B in every scenario.
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)

    def start_attack(attack_range, colors, faces):
        # Replace the current attack in the world state with a new one.
        world_state.updateAttack(
            AttackState(attack_range, ship_a, 'left', ship_b, 'front', colors, faces))

    def encode_world_and_action():
        # World state followed by an empty (None) action for the current sub phase.
        return torch.cat((Encodings.encodeWorldState(world_state),
                          Encodings.encodeAction(world_state.sub_phase, None)))

    def network_input(world_and_action):
        # Append the encoding of the current attack to the given prefix.
        attack_encoding = Encodings.encodeAttackState(world_state)
        return torch.cat(
            (world_and_action.to(device), attack_encoding.to(device)))

    # Scenario 1: strong roll against a defender with a single hull point left.
    strong_colors, strong_faces = ['black'] * 4, ['hit_crit'] * 4
    world_state.setPhase("ship phase", sub_phase)
    ship_b.set('damage', ship_b.get('hull') - 1)
    start_attack('short', strong_colors, strong_faces)
    world_and_action = encode_world_and_action()
    batch[0] = network_input(world_and_action)

    # Scenario 2: the same roll, but the defender is undamaged.
    ship_b.set('damage', 0)
    start_attack('short', strong_colors, strong_faces)
    world_and_action = encode_world_and_action()
    batch[1] = network_input(world_and_action)

    # Scenario 3: undamaged defender and nothing but blanks. The world/action
    # prefix from scenario 2 is reused; only the attack encoding is refreshed.
    world_state.setPhase("ship phase", sub_phase)
    start_attack('short', ['black'] * 4, ['blank'] * 4)
    batch[2] = network_input(world_and_action)

    # Scenario 4: undamaged defender, blanks only, fired at long (red) range.
    world_state.setPhase("ship phase", sub_phase)
    start_attack('long', ['red'] * 2, ['blank'] * 2)
    batch[3] = network_input(world_and_action)

    lifetime_out = model(batch)
    print("super cool attack effects round estimates are {}".format(lifetime_out))

    # Lifetimes should increase from one scenario to the next. When the ship is
    # not going to be destroyed the network cannot give a meaningful relative
    # number, so any pair of predictions that are both beyond round 6 is
    # accepted. The first scenario, however, must predict destruction.
    lifetimes = [lifetime_out[row].item() for row in range(batch.size(0))]
    assert lifetimes[0] < 6
    for earlier, later in zip(lifetimes, lifetimes[1:]):
        assert earlier < later or (earlier > 6. and later > 6.)
示例#16
0
def test_defense_tokens_model(spend_defense_tokens_model):
    """Check lifetime predictions of the spend defense tokens network.

    First confirms that the evaluation error dropped over training. Then three
    inputs are evaluated against a defender with 2 front shields and 3 hull
    remaining:

    1. A lethal attack with no defense token spent.
    2. The same attack with a green brace token spent.
    3. An attack from an identically prepared world with a green redirect
       token sending 4 damage to the left hull zone.

    The network must predict a shorter lifetime for 1 than for 2 and for 3.
    """
    model, _train_errors, eval_errors = spend_defense_tokens_model
    model.eval()
    sub_phase = "attack - spend defense tokens"
    # Network input layout: world state, then action, then attack state.
    input_size = (Encodings.calculateWorldStateSize() +
                  Encodings.calculateActionSize(sub_phase) +
                  Encodings.calculateAttackSize())

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    max_rows = 32
    # Only the first three rows are filled in and evaluated below.
    batch = torch.Tensor(max_rows, input_size).to(device)

    # Training is only considered successful if the evaluation error went down.
    print("Eval errors for B were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    def build_scenario(faces):
        # Fresh world where ship A fires four black dice at the front of ship B.
        state = WorldState()
        state.round = 1
        attacker = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=1)
        defender = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=2)
        state.addShip(attacker, 0)
        state.addShip(defender, 1)
        state.setPhase("ship phase", sub_phase)
        # Front hull zone gets 2 shields.
        defender.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
        # Assign enough damage that exactly 3 hull points remain.
        defender.set('damage', defender.get('hull') - 3)
        state.updateAttack(
            AttackState('short', attacker, 'left', defender, 'front', ['black'] * 4, faces))
        return state

    def network_input(state, action):
        # World state, the given action and the attack state, concatenated.
        world_and_action = torch.cat((Encodings.encodeWorldState(state),
                                      Encodings.encodeAction(state.sub_phase, action)))
        attack_encoding = Encodings.encodeAttackState(state)
        return torch.cat(
            (world_and_action.to(device), attack_encoding.to(device)))

    # Four hit+crit faces against 2 shields and 3 hull.
    lethal_world = build_scenario(['hit_crit'] * 4)
    # Row 0: no token spent.
    batch[0] = network_input(lethal_world, None)
    # Row 1: same attack, green brace spent.
    batch[1] = network_input(lethal_world, [("brace", (ArmadaTypes.green, None))])

    # Row 2: new world with a mixed hit_crit/hit roll and a green redirect that
    # moves 4 damage to the left hull zone.
    redirect_world = build_scenario(['hit_crit'] * 2 + ['hit'] * 2)
    batch[2] = network_input(redirect_world,
                             [("redirect", (ArmadaTypes.green, [('left', 4)]))])

    round_status = model(batch[:3])
    print("super cool estimated rounds of destructions are {}".format(round_status[:3]))

    # Spending no defense token leads to destruction, so that prediction must
    # be the lowest of the three.
    unprotected = round_status[0].item()
    assert unprotected < round_status[1].item()
    assert unprotected < round_status[2].item()
def collect_attack_batches(batch_size, attacks, subphase):
    """Generate training batches from a list of attack logs.

    Every sequence in `attacks` is scanned for actions taken while the world
    state was in `subphase`. Each such action is paired with the most recent
    state seen in that subphase. Only a single (state, action) pair is sampled
    from each sequence into the training batch. This avoids the network simply
    memorizing the output of specific scenarios in the event of fairly unique
    events (for example a certain unlikely dice roll or combination of shields
    and hull in the defending ship).

    Each sample row is laid out as world state encoding, action encoding and
    attack state encoding, in that order. Its label is the round recorded by
    the last 'state' entry of the sequence (per the original author this is
    the round where the ship was destroyed, or 7 if it survived all 6 rounds).

    Args:
        batch_size (int)            : Maximum batch size.
        attacks (List[List[tuples]]): Each sublist is all of the states and actions from a sequence.
                                      Entries are tuples whose first item is 'state' or 'action'
                                      and whose second item is the state or the action.
        subphase (str)              : Name of the subphase where state/action pairs are collected.
    Yields:
        (batch, labels) : Tuple of the batch and labels. Full batches are yielded as soon as they
                          are complete; a final partial batch is yielded if samples are left over.
                          New tensors are allocated for every batch, so yielded batches are not
                          overwritten by later iterations.
    """
    # The sizes depend only on the subphase, so they are computed once.
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(subphase)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    action_end = world_size + action_size

    batch = torch.zeros(batch_size, input_size)
    labels = torch.zeros(batch_size, 1)
    cur_sample = 0

    for sequence in attacks:
        # All bookkeeping is per sequence. Resetting it here guarantees that an
        # action can never be paired with a state (or labelled with a round)
        # left over from a previous sequence.
        last_state = None
        last_round = 0
        collecting = False
        # We only collect a single sample per sequence. Collect all of the state action pairs of
        # interest first and then select a single one to use.
        state_action_pairs = []
        for entry in sequence:
            entry_kind = entry[0]
            if 'state' == entry_kind:
                last_round = entry[1].round
                # Actions are only collected while the latest state is in the
                # requested subphase; any other state pauses collection.
                collecting = entry[1].sub_phase == subphase
                if collecting:
                    last_state = entry[1]
            elif 'action' == entry_kind and collecting:
                state_action_pairs.append((last_state, entry[1]))

        # Collect a single sample (as long as one was present)
        if state_action_pairs:
            state, action = random.choice(state_action_pairs)
            labels[cur_sample] = last_round
            # The encoders are handed slices of the batch row through their
            # `encoding` argument and are relied upon to fill them in place.
            Encodings.encodeWorldState(world_state=state,
                                       encoding=batch[cur_sample, :world_size])
            # Collected states are always in `subphase`, so it is used directly.
            Encodings.encodeAction(subphase=subphase,
                                   action_list=action,
                                   encoding=batch[cur_sample,
                                                  world_size:action_end])
            Encodings.encodeAttackState(world_state=state,
                                        encoding=batch[cur_sample, action_end:])
            cur_sample += 1

        # When a full batch is collected return it immediately.
        if cur_sample == batch_size:
            yield (batch[:cur_sample], labels[:cur_sample])
            # Make new buffers to store training data
            batch = torch.zeros(batch_size, input_size)
            labels = torch.zeros(batch_size, 1)
            cur_sample = 0

    # If there are leftover samples that did not fill a batch still return them.
    if 0 < cur_sample:
        yield (batch[:cur_sample], labels[:cur_sample])