コード例 #1
0
import torch
from torch import nn
import numpy as np
import copy
import time
import os
from tqdm import tqdm
from torch_truncnorm.TruncatedNormal import TruncatedNormal
import random
import asyncio
from derk_PPO_LSTM import lstm_agent
from reward_functions import *

# --- Run configuration ------------------------------------------------------
device = "cuda:0"  # every agent below is constructed on this device
ITERATIONS = 1000000

# First positional argument passed to every lstm_agent built in this block
# (presumably the LSTM width -- confirm in derk_PPO_LSTM). Shared so the main
# agent and the league members are always constructed with the same value.
AGENT_SIZE = 512

agent = lstm_agent(AGENT_SIZE, device)

# Weapon name pools, one list per slot type named by the variable.
arm_weapons = [
    "Talons", "BloodClaws", "Cleavers", "Cripplers", "Pistol", "Magnum",
    "Blaster"
]
misc_weapons = [
    "FrogLegs", "IronBubblegum", "HeliumBubblegum", "Shell", "Trombone"
]
tail_weapons = ["HealingGland", "VampireGland", "ParalyzingDart"]

# NOTE(review): presumably the number of parallel arenas -- confirm at use site.
n_arenas = 80

# League of freshly constructed agents; no weights are loaded in this block.
league_size = 10
league = [lstm_agent(AGENT_SIZE, device) for _ in range(league_size)]
コード例 #2
0
device = "cuda:0"
ITERATIONS = 1000000
root_dir = "checkpoints/TEST_LEAGUE_AGENTS"

league = []
member_names = []

k = 30  # ELO weight parameter
k_decay = 0.99  # NOTE(review): presumably a multiplicative decay for k -- confirm at use site

# Load every checkpoint whose filename contains "best", from anywhere under
# root_dir, into its own agent. Each agent is tagged with a sequential id,
# a starting ELO of 1000 and the checkpoint's filename.
count = 0
for root, dirs, files in os.walk(root_dir):
    for name in files:  # Load in the agents stored at root_dir
        if "best" in name:
            temp = lstm_agent(1024, device)
            # map_location loads the tensors straight onto `device`, so a
            # checkpoint saved from a different device still loads here.
            temp.load_state_dict(
                torch.load(os.path.join(root, name), map_location=device))
            temp.id = count
            temp.ELO = 1000
            temp.name = name
            count += 1
            league.append(temp)

random.shuffle(league)
# Number of policies.  Must be even because we don't want byes or anything like that.
league_size = len(league)
if league_size == 0:
    # os.walk yields nothing for a missing directory, and 0 is even, so
    # without this check a wrong root_dir would silently give an empty league.
    raise FileNotFoundError(
        f"No checkpoints with 'best' in the filename found under {root_dir!r}")
if league_size % 2 != 0:
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    raise ValueError(
        "Number of policies in the TEST_LEAGUE_AGENTS folder must be even")
arm_weapons = [["Pistol", "Magnum", "Blaster"],
               ["Talons", "BloodClaws", "Cleavers", "Cripplers"],
コード例 #3
0
ファイル: derk_train.py プロジェクト: Richie78321/rl-moba
import time
import torch
from torch import nn
import numpy as np
import copy
import time
import os
from tqdm import tqdm
from torch_truncnorm.TruncatedNormal import TruncatedNormal
import random
from derk_PPO_LSTM import lstm_agent
from reward_functions import *

# --- Run configuration ------------------------------------------------------
ITERATIONS = 1_000_000
device = "cuda:0"

# The single agent for this script, constructed on `device`.
agent = lstm_agent(512, device)

# Weapon name pools, one list per slot type named by the variable.
arm_weapons = [
    "Talons",
    "BloodClaws",
    "Cleavers",
    "Cripplers",
    "Pistol",
    "Magnum",
    "Blaster",
]
misc_weapons = [
    "FrogLegs",
    "IronBubblegum",
    "HeliumBubblegum",
    "Shell",
    "Trombone",
]
tail_weapons = [
    "HealingGland",
    "VampireGland",
    "ParalyzingDart",
]

# NOTE(review): presumably the number of parallel arenas -- confirm at use site.
n_arenas = 80
random_configs = [{
    "slots": [
        random.choice(arm_weapons),
        random.choice(misc_weapons),
コード例 #4
0
    'continuous_entropy_coeff': perturb_explore,
    #'continuous_coeff': perturb_explore,
    'value_coeff': perturb_explore,
    'minibatch_size': minibatch_discrete_perturb_explore,
    "lstm_fragment_length": fragment_length_perturb_explore,
}

def _log_uniform(low_exp, high_exp):
    """Return ``10 ** U(low_exp, high_exp)`` drawn from NumPy's global RNG."""
    return 10**np.random.uniform(low_exp, high_exp)


def _sample_initial_hyperparams():
    """Draw one starting hyperparameter set for a new population member.

    Numeric values are log-uniform: the exponent is uniform and the value is
    10 raised to it. The LSTM fragment length is picked uniformly from
    ``fragment_length_choices``. Keys and draw order match the original
    inline dict: five NumPy draws, then one stdlib ``random`` choice.
    """
    return {
        'learning_rate': _log_uniform(-5, -3),
        'discrete_entropy_coeff': _log_uniform(-6, -4),
        'continuous_entropy_coeff': _log_uniform(-6, -4),
        # 'continuous_coeff' (10 ** U(0.8, 2)) is currently disabled.
        'value_coeff': _log_uniform(-1, 0.3),
        'minibatch_size': int(_log_uniform(2.4, 3.6)),
        "lstm_fragment_length": int(random.choice(fragment_length_choices)),
    }


# Build the population; each member gets its own independently sampled
# (log-uniform) hyperparameter set.
population = [
    lstm_agent(1024, device, hyperparams=_sample_initial_hyperparams())
    for _ in range(population_size)
]

print(population[0].get_hyperparams())

# Per-member record of the last PBT update; every member starts at 0.
last_PBT_update = [0 for _ in population]

# 2 * floor(step ** 1.6) for step in 0..999: spacing grows as step grows.
# NOTE(review): presumably the iterations at which models are checkpointed.
model_checkpoint_schedule = [2 * int(step**1.6) for step in range(1000)]
save_folder = f"checkpoints/PPO-LSTM-PBT{time.time()}"