# Example 1
    def __init__(self, with_novelty=True):
        # Two extra outputs are added for the mean and variance estimation.

        nn.Module.__init__(self)
        self.with_novelty = with_novelty
        def_in = Encodings.calculateAttackSize()
        def_out = Encodings.calculateSpendDefenseTokensSize() + 2
        self.models = nn.ModuleDict()
        self.models["def_tokens"] = self.init_fc_params(def_in, def_out)
        self.optimizers = {
            "def_tokens":
            torch.optim.Adam(self.models["def_tokens"].parameters())
        }
        if self.with_novelty:
            # We also create two more models to use for random network distillation. The
            # first random model will remain static and the second will learn to predict the
            # outputs of the first. The difference between the two outputs will be used to
            # estimate the novelty of the current state. If the state is new then the second
            # model will not be able to make a good prediction of the first model's outputs.
            # In other words, the first model projects the inputs into a new latent space.
            # The ability of the second model to predict that projection should be correlated
            # with how similar this state is to ones we have previously visited (see the
            # novelty_estimate sketch after this constructor).

            # The novelty network is a clone of the corresponding network
            self.models["def_tokens_novelty"] = nn.ModuleList([
                self.init_fc_params(def_in, def_out), self.models["def_tokens"]
            ])
            self.models["def_tokens_static"] = self.init_fc_params(
                def_in, def_out)
            self.optimizers["def_tokens_novelty"] = torch.optim.Adam(
                self.models["def_tokens_novelty"].parameters())

        # Specify dim explicitly; the implicit softmax dimension is deprecated.
        self.sm = nn.Softmax(dim=1)
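    # A minimal sketch (hypothetical helper, not part of the original class) of the novelty
    # computation described above. It assumes init_fc_params returns a module mapping an
    # attack encoding to def_out values, matching the layout built in __init__.
    def novelty_estimate(self, encoding):
        """Return a per-sample novelty score via random network distillation."""
        # The static network's projection is the prediction target; it is never trained.
        with torch.no_grad():
            target = self.models["def_tokens_static"](encoding)
        prediction = self.models["def_tokens_novelty"][0](encoding)
        # A large prediction error implies a state unlike ones previously visited.
        return (prediction - target).pow(2).mean(dim=1)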
# Example 2
    def __init__(self, model=None):
        """Initialize the simple agent with a couple of simple state handlers.
        
        Args:
            model (torch.nn.Module or None): If None this agent will pass for all supported states
        """
        handler = {
            "ship phase - attack - resolve attack effects":
            self.resolveAttackEffects,
            "ship phase - attack - spend defense tokens":
            self.spendDefenseTokens
        }
        super(LearningAgent, self).__init__(handler)
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = model
        if self.model is not None:
            self.model = self.model.to(self.device)
            self.model.eval()

        self.attack_enc_size = Encodings.calculateAttackSize()
        self.memory = []
        self.remembering = False
# Example 3
    def __init__(self):
        # There is a different model for each phase.
        # Each model takes the world state and the attack state as input and outputs an action.

        world_size = Encodings.calculateWorldStateSize()
        attack_size = Encodings.calculateAttackSize()

        nn.Module.__init__(self)
        self.models = nn.ModuleDict()
        self.optimizers = {}

        # TODO FIXME Loop through all of the possible phases instead of just these two
        for phase_name in [
                "attack - resolve attack effects",
                "attack - spend defense tokens"
        ]:
            action_size = Encodings.calculateActionSize(phase_name)
            self.models[phase_name] = self.init_fc_params(
                world_size + attack_size, action_size)
            self.optimizers[phase_name] = torch.optim.Adam(
                self.models[phase_name].parameters())

        # Specify dim explicitly; the implicit softmax dimension is deprecated.
        self.sm = nn.Softmax(dim=1)
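    # A minimal sketch (hypothetical, not part of the original class) of the phase-dispatching
    # forward function suggested by a TODO in test_policy_learning below.
    def forward(self, phase_name, encoding):
        """Run the model for the given phase and return action probabilities."""
        return self.sm(self.models[phase_name](encoding))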
def test_defense_tokens_model(spend_defense_tokens_model):
    """Test basic network learning loop.

    Test lifetime predictions during the spend defense tokens phase.
    """
    network, errors, eval_errors = spend_defense_tokens_model
    network.eval()
    phase_name = "attack - spend defense tokens"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 32
    batch = torch.Tensor(batch_size, input_size).to(target_device)

    print("Eval errors for B were {}".format(eval_errors))
    #print("First and last errors in b are {} and {}".format(eval_errors[0], eval_errors[-1]))

    assert eval_errors[0] > eval_errors[-1]

    # Let's examine predictions for different dice pools and spent defense tokens.
    # Go through the following scenarios:
    # 1.1 An attack with more than enough damage to destroy the ship
    # 1.2 The same attack but with a brace that would prevent destruction
    # 1.3 A similar attack (with a smaller dice pool) but with a redirect that would
    #     prevent destruction
    # Expected result: 1.1 should predict a lower lifetime than 1.2 and 1.3.
    # Two further scenarios are described here but not yet encoded below:
    # 2.1 An attack that can barely destroy the ship
    # 2.2 An attack that barely will not destroy the ship
    # Expected result: 2.1 should predict a lower lifetime than 2.2.
    # Ideally 1.1 and 2.1 would predict the current round.
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 4
    world_state.setPhase("ship phase", phase_name)
    # Set the front hull zone to 2 shields
    ship_b.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
    # Set the hull to 3 (by assigning damage to reduce the remaining hull to 3)
    ship_b.set('damage', ship_b.get('hull') - 3)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[0] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    action = [("brace", (ArmadaTypes.green, None))]
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, action)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[1] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 2 + ['hit'] * 2
    world_state.setPhase("ship phase", phase_name)
    # Set the front hull zone to 2 shields
    ship_b.get_range('shields')[ArmadaTypes.hull_zones.index('front')] = 2
    # Set the hull to 3 (by assigning damage to reduce the remaining hull to 3)
    ship_b.set('damage', ship_b.get('hull') - 3)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)

    action = [("redirect", (ArmadaTypes.green, [('left', 4)]))]
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, action)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[2] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    round_status = network(batch[:3])
    print("super cool estimated rounds of destructions are {}".format(round_status[:3]))

    # Using no defense token results in destruction, so the predicted final round should be lower
    assert(round_status[0].item() < round_status[1].item())
    assert(round_status[0].item() < round_status[2].item())
def test_resolve_attack_effects_model(resolve_attack_effects_model):
    """Test basic network learning loop.

    Test lifetime predictions during the resolve attack effects phase.
    """
    network, errors, eval_errors = resolve_attack_effects_model
    network.eval()
    phase_name = "attack - resolve attack effects"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size

    # First verify that errors decreased during training.
    # print("Errors for A were {}".format(errors))
    print("Eval errors for A were {}".format(eval_errors))
    assert eval_errors[0] > eval_errors[-1]

    # Let's examine predictions for different ranges and hull zones.
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = 4
    batch = torch.Tensor(batch_size, input_size).to(target_device)

    # Let's examine predictions for different dice pools and spent defense tokens.
    # Go through the following scenarios:
    # 1.1 An attack upon a ship with only 1 hull remaining
    # 1.2 The same dice pool but on a ship with full hull
    # 1.3 A dice pool with only blank dice
    # 1.4 A dice pool with only blanks when attacking at long range.

    # Create a state from resolve attack effects and an empty action.
    world_state = WorldState()
    world_state.round = 1
    ship_a = Ship(name="Ship A", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=1)
    ship_b = Ship(name="Ship B", template=ship_templates["All Defense Tokens"], upgrades=[], player_number=2)
    world_state.addShip(ship_a, 0)
    world_state.addShip(ship_b, 1)
    pool_colors, pool_faces = ['black'] * 4, ['hit_crit'] * 4
    world_state.setPhase("ship phase", "attack - resolve attack effects")
    ship_b.set('damage', ship_b.get('hull') - 1)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[0] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    # Same dice pool but the defender has full hull
    ship_b.set('damage', 0)
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    action_encoding = torch.cat((Encodings.encodeWorldState(world_state),
                                 Encodings.encodeAction(world_state.sub_phase, None)))
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[1] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    # Full hull and all blanks. The action encoding from the previous scenario is reused
    # since the round, sub-phase, and (empty) action are unchanged.
    pool_colors, pool_faces = ['black'] * 4, ['blank'] * 4
    world_state.setPhase("ship phase", "attack - resolve attack effects")
    attack = AttackState('short', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[2] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    # Full hull, all blanks, firing at red range
    pool_colors, pool_faces = ['red'] * 2, ['blank'] * 2
    world_state.setPhase("ship phase", "attack - resolve attack effects")
    attack = AttackState('long', ship_a, 'left', ship_b, 'front', pool_colors, pool_faces)
    world_state.updateAttack(attack)
    state_encoding = Encodings.encodeAttackState(world_state)
    batch[3] = torch.cat(
        (action_encoding.to(target_device), state_encoding.to(target_device)))

    lifetime_out = network(batch)
    print("super cool attack effects round estimates are {}".format(lifetime_out))

    # The lifetimes should go up sequentially with the above scenarios. However, if the ship
    # won't be destroyed within the game's six rounds the network cannot be expected to produce
    # an accurate relative ordering, so be lenient once lifetimes go above round 6. The first
    # scenario should result in destruction however.
    assert(lifetime_out[0].item() < 6)
    for i in range(batch.size(0) - 1):
        assert(lifetime_out[i].item() < lifetime_out[i+1].item() or 
                (lifetime_out[i].item() > 6. and lifetime_out[i+1].item() > 6.))
def update_lifetime_network(lifenet, batch, labels, optimizer, eval_only=False):
    """Do a forward and backward pass through the given lifetime network.

    Args:
        lifenet (torch.nn.Module): Trainable torch model
        batch (torch.tensor)     : Training batch
        labels (torch.tensor)    : Training labels
        optimizer (torch.optim.Optimizer) : Optimizer for lifenet parameters (may be None when
                                            eval_only is True).
        eval_only (bool)         : Only evaluate, don't update parameters.
    Returns:
        float : Average absolute error for this batch.
    """
    if eval_only:
        lifenet.eval()
    # Forward through the prediction network
    prediction = lifenet.forward(batch)
    # Loss is the lifetime prediction error
    # The output cannot be negative, run through a ReLU to clean that up
    #f = torch.nn.ReLU()
    #epsilon = 0.001
    #normed_predictions = f(prediction[0]) + epsilon
    with torch.no_grad():
        errors = (prediction - labels).abs()
        error = errors.mean().item()
        for i in range(errors.size(0)):
            # Debug on crazy errors or nan values.
            if errors[i] > 1000 or errors[i] != errors[i]:
                # This is messy debugging code, but sometimes a test may fail after bug
                # introductions so it is helpful to leave this in to speed up debugging.
                world_size = Encodings.calculateWorldStateSize()
                world_encoding = batch[i,:world_size]
                phase_name = ArmadaPhases.main_phases[int(world_encoding[1].item())]
                sub_phase_name = ArmadaPhases.sub_phases[phase_name][int(world_encoding[2].item())]
                print(f"Error {i} is {errors[i]}")
                print(f"\tRound {world_encoding[0].item()}")
                print(f"\tSubphase {sub_phase_name}")
                action_size = Encodings.calculateActionSize(sub_phase_name)
                attack_size = Encodings.calculateAttackSize()
                action_encoding = batch[i,world_size:world_size + action_size]
                attack_state_encoding = batch[i,world_size + action_size:]
                if "attack - resolve attack effects" == sub_phase_name:
                    print(f"\tattack effect encoding is {action_encoding}")
                elif "attack - spend defense tokens" == sub_phase_name:
                    print(f"\tspend defense token encoding is {action_encoding}")
                else:
                    print("Cannot print information about {}".format(sub_phase_name))
                defender = Ship(name="Defender", player_number=1,
                                encoding=attack_state_encoding[:Ship.encodeSize()])
                attacker = Ship(name="Attacker", player_number=1,
                                encoding=attack_state_encoding[Ship.encodeSize():2 * Ship.encodeSize()])
                # print(f"\tAttack state encoding is {attack_state_encoding}")
                print("\tAttacker is {}".format(attacker))
                print("\tDefender is {}".format(defender))
                # TODO FIXME Enough dice in a pool seems to end a ship, but unless the pools are
                # incredibly large this doesn't seem to be happening. Damage does not seem to be
                # accumulating between rounds.
                die_offset = Encodings.getAttackDiceOffset()
                dice_encoding = attack_state_encoding[die_offset:die_offset + Encodings.dieEncodingSize()]
                print("\tDice are {}".format(dice_encoding))
                print(f"\tLabel is {labels[i]}")


    # Normal distribution:
    #normal = torch.distributions.normal.Normal(predict_tensor[0], predict_tensor[1])
    #loss = -normal.log_prob(labels)
    # Poisson distribution (works well)
    #poisson = torch.distributions.poisson.Poisson(normed_predictions)
    #loss = -poisson.log_prob(labels)
    # Plain old MSE error
    #loss_fn = torch.nn.MSELoss()
    #loss = loss_fn(prediction, labels)
    # Absolute error
    loss_fn = torch.nn.L1Loss()
    loss = loss_fn(prediction, labels)

    if not eval_only:
        optimizer.zero_grad()
        loss.sum().backward()
        optimizer.step()
    else:
        lifenet.train()
    return error
def test_policy_learning(spend_defense_tokens_model, resolve_attack_effects_model):
    """Train a model to produce a better than random choice policy for defense tokens.

    The spend_defense_tokens_model is used to judge the quality of this network's output.
    There is no update step as in reinforcement learning; this just tests the mechanism by
    asserting that the policy losses decrease during training.
    """
    def_tokens_model, _, _ = spend_defense_tokens_model
    def_tokens_model.eval()
    res_attack_model, _, _ = resolve_attack_effects_model
    res_attack_model.eval()
    prediction_models = {
        "attack - spend defense tokens": def_tokens_model.eval(),
        "attack - resolve attack effects": res_attack_model.eval()
    }
    # Do the training. Use the prediction model lifetime to create the loss target. The loss
    # will be the difference between the max possible round and the predicted round.
    # For the defense token model the higher round is better, for the attack effect model a
    # lower round is better.
    loss_fn = {
        "attack - spend defense tokens": lambda predictions: 7.0 - predictions,
        "attack - resolve attack effects": lambda predictions: predictions - 1.
    }
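    # For example, with the defense token loss a predicted destruction round of 7 (the ship
    # survives all six game rounds) gives a loss of 0, while a predicted destruction in
    # round 2 gives a loss of 5; the attack effect loss is the mirror image.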

    # Generate a new learning model
    learning_agent = LearningAgent(SeparatePhaseModel())
    # TODO FIXME The learning agent doesn't really have a training mode
    learning_agent.model.train()
    random_agent = RandomAgent()

    training_ships = ["All Defense Tokens", "All Defense Tokens",
                      "Imperial II-class Star Destroyer", "MC80 Command Cruiser",
                      "Assault Frigate Mark II A", "No Shield Ship", "One Shield Ship",
                      "Mega Die Ship"]
    defenders = []
    attackers = []
    for name in training_ships:
        attackers.append(Ship(name=name, template=ship_templates[name],
                              upgrades=[], player_number=1, device='cpu'))
    for name in training_ships:
        defenders.append(Ship(name=name, template=ship_templates[name],
                              upgrades=[], player_number=2, device='cpu'))

    batch_size = 32
    # Remember the training loss values to test for improvement
    losses = {}
    for subphase in ["attack - spend defense tokens", "attack - resolve attack effects"]:
        losses[subphase] = []
    # This gets samples to use for training. We will use a random agent to generate the states.
    # In reinforcement learning the agent would alternate between random actions and actions
    # from the learning agent to balance exploration of the state space with exploitation of
    # learned behavior.
    samples = get_n_examples(
        n_examples=1000, ship_a=attackers, ship_b=defenders, agent=random_agent)

    # Get a batch for each subphase
    for subphase in ["attack - spend defense tokens", "attack - resolve attack effects"]:
        # The learning_agent will take in the world state and attack state and will produce a new
        # action encoding. Then the prediction network will take in a new tuple with this action
        # and predict a final round. The difference between the predicted final round for the
        # random action and for the network's action would be the reward in a full reinforcement
        # learning setup.
        world_size = Encodings.calculateWorldStateSize()
        action_size = Encodings.calculateActionSize(subphase)
        attack_size = Encodings.calculateAttackSize()
        for batch, lifetimes in collect_attack_batches(batch_size=batch_size, attacks=samples, subphase=subphase):
            # Move the batch to the same device as the learning agent's model.
            batch = batch.to(learning_agent.device)
            # The learning agent takes in the world state along with the action as the input tensor.
            new_batch = torch.cat((batch[:,:world_size], batch[:,world_size + action_size:]), dim=1)
            # TODO FIXME Just make a forward function that takes in a phase name (a sketch of such
            # a function appears under Example 3's __init__ above)
            new_actions = learning_agent.model.models[subphase](new_batch)
            new_action_state = torch.cat((batch[:,:world_size], new_actions, batch[:,world_size + action_size:]), dim=1)
            action_result = prediction_models[subphase](new_action_state)
            loss = loss_fn[subphase](action_result)
            learning_agent.model.get_optimizer(subphase).zero_grad()
            loss.sum().backward()
            learning_agent.model.get_optimizer(subphase).step()
            with torch.no_grad():
                losses[subphase].append(loss.mean().item())
            # In reinforcement learning there would also be a phase where the prediction models are
            # updated.
    for subphase in ["attack - spend defense tokens", "attack - resolve attack effects"]:
        print(f"losses for {subphase} start with {losses[subphase][0:5]} and end with {losses[subphase][-5:]}")
        assert losses[subphase][-1] < losses[subphase][0]
    # TODO FIXME HERE See if the policy networks produce better results than random actions
    learning_agent.model.models.eval()
def spend_defense_tokens_model(create_spend_defense_tokens_dataset):
    """Train a basic lifetime prediction model.

    Create a simple network that predicts defending ship lifetimes during the 'spend defense
    tokens' phase.

    Returns:
        (nn.module, [float], [float]): The model, training errors per batch, and eval errors per
                                       evaluation batch.
    """
    attacks, eval_attacks = create_spend_defense_tokens_dataset

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Lifetime network B predicts lifetimes given a state and action pair in the 'spend defense
    # tokens' subphase.
    phase_name = "attack - spend defense tokens"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    network = torch.nn.Sequential(
        torch.nn.BatchNorm1d(input_size),
        torch.nn.Linear(input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 4 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(4 * input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 1))
    network.to(target_device)
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0005)

    # Keep track of the errors for the purpose of this test
    errors = []
    eval_errors = []

    # Evaluate before training and every epoch
    for batch in eval_attacks:
        eval_data = batch[0].to(target_device)
        eval_labels = batch[1].to(target_device)
        eval_errors.append(update_lifetime_network(network, eval_data,
                                                   eval_labels, None, True))
    # Train with all of the data for 10 epochs
    for epoch in range(10):
        print("Training resolve_spend_defense_tokens_model epoch {}".format(epoch))
        train_batches = 0
        for batch in attacks:
            train_data = batch[0].to(target_device)
            train_labels = batch[1].to(target_device)
            errors.append(update_lifetime_network(network, train_data,
                                                  train_labels, optimizer))
            train_batches += 1
        print("Finished epoch with {} batches.".format(train_batches))
        # Evaluate every epoch
        for batch in eval_attacks:
            eval_data = batch[0].to(target_device)
            eval_labels = batch[1].to(target_device)
            eval_errors.append(update_lifetime_network(network, eval_data,
                                                       eval_labels, None, True))

    return network, errors, eval_errors
def resolve_attack_effects_model(create_attack_effects_dataset):
    """Train some basic lifetime prediction models.

    Create a simple network that predict defending ship lifetimes during the 'resolve attack
    effects' phase. Also creates logs of training and evaluation loss.

    Returns:
        (nn.module, [float], [float]): The model, training errors per epoch, and eval errors per
                                       epoch.
    """
    attacks, eval_attacks = create_attack_effects_dataset

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Lifetime network A predicts lifetimes given a state and action pair in the 'resolve attack
    # effects' subphase.
    phase_name = "attack - resolve attack effects"
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(phase_name)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    # The network size was made large enough that training plateaued to a stable value
    # If the network is too large it has a tendency to start fitting to specific cases.
    # Batchnorm doesn't seem to help out much with this network and task.
    network = torch.nn.Sequential(
        torch.nn.BatchNorm1d(input_size),
        torch.nn.Linear(input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 4 * input_size),
        torch.nn.ELU(),
        #torch.nn.Dropout(),
        #torch.nn.BatchNorm1d(4 * input_size),
        torch.nn.Linear(4 * input_size, 2 * input_size),
        torch.nn.ELU(),
        torch.nn.Linear(2 * input_size, 1))
    network.to(target_device)
    # Higher learning rates lead to a lot of instability in the training.
    optimizer = torch.optim.Adam(network.parameters(), lr=0.0005)

    # Keep track of the errors for the purpose of this test
    errors = []
    eval_errors = []

    # Evaluate before training and every epoch
    for batch in eval_attacks:
        eval_data = batch[0].to(target_device)
        eval_labels = batch[1].to(target_device)
        eval_errors.append(update_lifetime_network(network, eval_data,
                                                   eval_labels, None, True))
    # Train with all of the data for 10 epochs
    for epoch in range(10):
        print("Training resolve_attack_effects_model epoch {}".format(epoch))
        epoch_samples = 0
        train_batches = 0
        for batch in attacks:
            train_data = batch[0].to(target_device)
            train_labels = batch[1].to(target_device)
            errors.append(update_lifetime_network(network, train_data,
                                                  train_labels, optimizer))
            epoch_samples += batch[0].size(0)
            train_batches += 1
        print("Finished epoch with {} batches.".format(train_batches))

        # Evaluate every epoch
        for batch in eval_attacks:
            eval_data = batch[0].to(target_device)
            eval_labels = batch[1].to(target_device)
            eval_errors.append(update_lifetime_network(network, eval_data,
                                                       eval_labels, None, True))

    return network, errors, eval_errors
def collect_attack_batches(batch_size, attacks, subphase):
    """A generator to collect training batches from a list of attack logs.

    Collect all of the actions taken during the given subphase of a trial and associate them with
    the round number when the trial ended. Only sample a single action-state pair from each trial
    into the training batch though. This avoids the network simply memorizing the output of
    specific scenarios in the event of fairly unique events (for example a certain unlikely dice
    roll or combination of shields and hull in the defending ship).

    Args:
        batch_size (int)            : Maximum batch size.
        attacks (List[List[tuples]]): Each sublist is all of the states and actions from a sequence.
        subphase (str)              : Name of the subphase where state/action pairs are collected.
    Yields:
        (batch, labels) : Tuple of a training batch and its labels.
    """
    # New tensors are created for each batch (rather than reusing a caller-provided buffer) so
    # that this generator can also serve as a dataset iterator for a multithreaded dataloader.
    world_size = Encodings.calculateWorldStateSize()
    action_size = Encodings.calculateActionSize(subphase)
    attack_size = Encodings.calculateAttackSize()
    input_size = world_size + action_size + attack_size
    batch = torch.zeros(batch_size, input_size)
    labels = torch.zeros(batch_size, 1)

    # Variables for collection
    # collect_state records what we are doing inside of the sample loop
    collect_state = 0
    # The last state of interest seen in the current trial
    last_state = None
    cur_sample = 0
    last_round = 0
    for sequence in attacks:
        # We only collect a single sample per sequence. Collect all of the state action pairs of
        # interest first and then select a single one to use.
        state_action_pairs = []
        for attack in sequence:
            if 'state' == attack[0]:
                last_round = attack[1].round
                if attack[1].sub_phase == subphase:
                    last_state = attack[1]
                    collect_state = 1
                else:
                    # Not in the desired subphase and the attack trial is not complete.
                    # Waiting for the next attack in the trial or for the trial to end.
                    collect_state = 2
            elif 'action' == attack[0] and 1 == collect_state:
                # Collect the action associated with last_state.
                state_action_pairs.append((last_state, attack[1]))
        # Collect a single sample (as long as one was present)
        if 0 < len(state_action_pairs):
            selected = random.choice(state_action_pairs)
            # The training label is the final round where the ship was destroyed, or 7 if the ship
            # was not destroyed in the 6 game rounds.
            labels[cur_sample] = last_round
            Encodings.encodeWorldState(world_state=selected[0],
                                       encoding=batch[cur_sample, :world_size])
            Encodings.encodeAction(
                subphase=selected[0].sub_phase,
                action_list=selected[1],
                encoding=batch[cur_sample,
                               world_size:world_size + action_size])
            Encodings.encodeAttackState(world_state=selected[0],
                                        encoding=batch[cur_sample, world_size +
                                                       action_size:])
            cur_sample += 1
        # When a full batch is collected return it immediately.
        if cur_sample == batch_size:
            yield ((batch[:cur_sample], labels[:cur_sample]))
            # Make new buffers to store training data
            batch = torch.zeros(batch_size, input_size)
            labels = torch.zeros(batch_size, 1)
            cur_sample = 0

    # If there are leftover samples that did not fill a batch still return them.
    if 0 < cur_sample:
        yield ((batch[:cur_sample], labels[:cur_sample]))
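

# A minimal usage sketch for collect_attack_batches, assuming samples, network, optimizer,
# and target_device were prepared as in the fixtures above:
#
#     for batch, labels in collect_attack_batches(
#             batch_size=32, attacks=samples, subphase="attack - spend defense tokens"):
#         error = update_lifetime_network(network, batch.to(target_device),
#                                         labels.to(target_device), optimizer)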