def test_resolve_attack_effects_model(resolve_attack_effects_model):
    """Test basic network learning loop.

    Test lifetime predictions during the resolve attack effects phase. The test first checks
    that the evaluation error decreased during training and then checks that the predicted
    lifetimes are ordered sensibly over four attacks of decreasing lethality.

    Args:
        resolve_attack_effects_model (tuple): The (network, errors, eval_errors) triple provided
                                              by the fixture of the same name.
    """
    network, _errors, eval_errors = resolve_attack_effects_model
    network.eval()
    phase_name = "attack - resolve attack effects"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size

    # First verify that errors decreased during training.
    print("Eval errors for A were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    # Let's examine predictions for different dice pools and remaining hull.
    # Go through the following scenarios:
    # 1.1 An attack upon a ship with only 1 hull remaining
    # 1.2 The same dice pool but on a ship with full hull
    # 1.3 A dice pool with only blank dice
    # 1.4 A dice pool with only blanks when attacking at long range.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 4
    # Zero-initialised so that no row can ever carry uninitialised memory into the network.
    batch = torch.zeros(batch_size, input_size, device=target_device)

    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[],
                  player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)

    def encode_scenario(attack_range, pool_colors, pool_faces):
        """Install a new attack on world_state and return the full network input row.

        The world, (empty) action and attack encodings are all taken after the attack has been
        updated so that every part of the row describes the same state, which is also how
        training rows are assembled.
        """
        world_state.setPhase("ship phase", phase_name)
        attack = AttackState(attack_range, ship_a, 'left', ship_b, 'front', pool_colors,
                             pool_faces)
        world_state.updateAttack(attack)
        row = torch.cat((Encodings.encodeWorldState(world_state),
                         Encodings.encodeAction(world_state.sub_phase, None),
                         Encodings.encodeAttackState(world_state)))
        return row.to(target_device)

    # 1.1 Four black hit+crit dice against a defender with a single hull point remaining.
    ship_b.set('damage', ship_b.get('hull') - 1)
    batch[0] = encode_scenario('short', ['black'] * 4, ['hit_crit'] * 4)

    # 1.2 Same dice pool but the defender has full hull
    ship_b.set('damage', 0)
    batch[1] = encode_scenario('short', ['black'] * 4, ['hit_crit'] * 4)

    # 1.3 Full hull and all blanks
    batch[2] = encode_scenario('short', ['black'] * 4, ['blank'] * 4)

    # 1.4 Full hull, all blanks, firing at red range
    batch[3] = encode_scenario('long', ['red'] * 2, ['blank'] * 2)

    # Only predictions are needed, so skip building the autograd graph.
    with torch.no_grad():
        lifetime_out = network(batch)
    print("super cool attack effects round estimates are {}".format(lifetime_out))

    # The lifetimes should go up sequentially with the above scenarios.
    # However if the ship won't be destroyed the NN can't make an accurate relative number so be
    # lenient once lifetimes go above round 6. The first scenario should result in destruction
    # however.
    lifetimes = [lifetime_out[i].item() for i in range(batch.size(0))]
    assert lifetimes[0] < 6
    for current, following in zip(lifetimes, lifetimes[1:]):
        assert current < following or (current > 6. and following > 6.)
def test_defense_tokens_model(spend_defense_tokens_model):
    """Test basic network learning loop.

    Test lifetime predictions during the spend defense tokens phase. The test first checks that
    the evaluation error decreased during training and then compares the predicted lifetime of a
    defender that spends no token against one that braces and one that redirects.

    Args:
        spend_defense_tokens_model (tuple): The (network, errors, eval_errors) triple provided by
                                            the fixture of the same name.
    """
    network, _errors, eval_errors = spend_defense_tokens_model
    network.eval()
    phase_name = "attack - spend defense tokens"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # One row per scenario that is actually evaluated below. Zero-initialised so that no
    # uninitialised memory can reach the network.
    batch_size = 3
    batch = torch.zeros(batch_size, input_size, device=target_device)

    print("Eval errors for B were {}".format(eval_errors))

    assert eval_errors[0] > eval_errors[-1]

    # Let's examine predictions for different dice pools and spent defense tokens.
    # Go through the following scenarios:
    # 1.1 An attack with more than enough damage to destroy the ship
    # 1.2 The same attack but a brace that would prevent destruction
    # 1.3 A redirect that would prevent destruction (built with a 2 hit+crit, 2 hit pool)
    # Result: 1.1 should have lower lifetime than 1.2 and 1.3
    # TODO The following scenarios are not covered by this test yet:
    # 2.1 An attack that can barely destroy the ship
    # 2.2 An attack that barely will not destroy the ship
    # Result: 2.1 should have lower lifetime than 2.2.
    # Ideally 1.1 and 2.1 would predict the current round.

    def make_world(pool_faces):
        """Build a fresh world where Ship A attacks Ship B's front with four black dice."""
        world_state = WorldState()
        world_state.round = 1
        ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[],
                      player_number=1)
        ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[],
                      player_number=2)
        world_state.addShip(ship_a, 0)
        world_state.addShip(ship_b, 1)
        world_state.setPhase("ship phase", phase_name)
        # Set the front hull zone to 2 shields
        ship_b.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
        # Set the hull to 3 (by assigning damage to reduce the remaining hull to 3)
        ship_b.set('damage', ship_b.get('hull') - 3)
        attack = AttackState('short', ship_a, 'left', ship_b, 'front', ['black'] * 4,
                             pool_faces)
        world_state.updateAttack(attack)
        return world_state

    def encode_row(world_state, action):
        """Return the network input row for taking action in world_state."""
        row = torch.cat((Encodings.encodeWorldState(world_state),
                         Encodings.encodeAction(world_state.sub_phase, action),
                         Encodings.encodeAttackState(world_state)))
        return row.to(target_device)

    # 1.1 and 1.2 share a world: first no defense token is spent, then a brace.
    world_state = make_world(['hit_crit'] * 4)
    batch[0] = encode_row(world_state, None)
    batch[1] = encode_row(world_state, [("brace", (ArmadaTypes.green, None))])

    # 1.3 uses a fresh world and redirects four damage to the left hull zone.
    world_state = make_world(['hit_crit'] * 2 + ['hit'] * 2)
    batch[2] = encode_row(world_state, [("redirect", (ArmadaTypes.green, [('left', 4)]))])

    # Only predictions are needed, so skip building the autograd graph.
    with torch.no_grad():
        round_status = network(batch)
    print("super cool estimated rounds of destructions are {}".format(round_status[:3]))

    # Using no defense token results in destruction, the final round should be less
    assert round_status[0].item() < round_status[1].item()
    assert round_status[0].item() < round_status[2].item()
def collect_attack_batches(batch_size, attacks, subphase):
    """A generator to collect training batches from a list of attack logs.

    Collect all of the actions taken during the given subphase of a trial and associate them with
    the round number of the last state logged in that trial. Only sample a single action-state
    pair from each trial into the training batch though. This avoids the network simply memorizing
    the output of specific scenarios in the event of fairly unique events (for example a certain
    unlikely dice roll or combination of shields and hull in the defending ship).

    Fresh tensors are allocated for every batch, so a yielded batch is never overwritten by later
    iterations.

    Args:
        batch_size (int)            : Maximum batch size.
        attacks (List[List[tuples]]): Each sublist is all of the states and actions from a sequence.
                                      Entries are ('state', world state) or ('action', action).
        subphase (str)              : Name of the subphase where state/action pairs are collected.
    Yields:
        (batch, labels) : Tuple of the batch and labels. Both have one row per sample; the last
                          batch may hold fewer than batch_size rows.
    """
    # The batch width is fixed by the requested subphase, so the sizes are computed only once.
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(subphase)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    action_end = world_size + action_size

    batch = torch.zeros(batch_size, input_size)
    labels = torch.zeros(batch_size, 1)
    cur_sample = 0

    for sequence in attacks:
        # Collection state is per sequence: an action is only ever paired with a state that was
        # logged earlier in the same sequence, never with a leftover from the previous one.
        collecting = False
        last_state = None
        last_round = 0
        # We only collect a single sample per sequence. Collect all of the state action pairs of
        # interest first and then select a single one to use.
        state_action_pairs = []
        for attack in sequence:
            if 'state' == attack[0]:
                last_round = attack[1].round
                # Actions are only collected while the latest state is in the desired subphase.
                # Otherwise wait for the next matching state or for the trial to end.
                collecting = attack[1].sub_phase == subphase
                if collecting:
                    last_state = attack[1]
            elif 'action' == attack[0] and collecting:
                state_action_pairs.append((last_state, attack[1]))

        # Collect a single sample (as long as one was present)
        if state_action_pairs:
            state, action = random.choice(state_action_pairs)
            # The training label is the round of the last state in the sequence. By the original
            # author's convention this is the round where the ship was destroyed, or 7 if the ship
            # was not destroyed in the 6 game rounds.
            labels[cur_sample] = last_round
            # The encoders write directly into the slices of the batch row.
            Encodings.encodeWorldState(world_state=state,
                                       encoding=batch[cur_sample, :world_size])
            Encodings.encodeAction(subphase=state.sub_phase,
                                   action_list=action,
                                   encoding=batch[cur_sample, world_size:action_end])
            Encodings.encodeAttackState(world_state=state,
                                        encoding=batch[cur_sample, action_end:])
            cur_sample += 1

        # When a full batch is collected return it immediately.
        if cur_sample == batch_size:
            yield (batch[:cur_sample], labels[:cur_sample])
            # Make new buffers to store training data
            batch = torch.zeros(batch_size, input_size)
            labels = torch.zeros(batch_size, 1)
            cur_sample = 0

    # If there are leftover samples that did not fill a batch still return them.
    if 0 < cur_sample:
        yield (batch[:cur_sample], labels[:cur_sample])