import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import sys
import random

# Seed torch's global RNG with the project-wide seed so that parameter
# initialisation below is reproducible across runs.
from utils.tool import read_seed
torch.manual_seed(read_seed())


class InitLinear(nn.Module):
    """Fully connected layer with explicit uniform initialisation.

    Holds a ``weight`` parameter of shape ``(out_features, in_features)``
    and a ``bias`` parameter of shape ``(out_features,)``.  Both are drawn
    uniformly between ``-1/sqrt(in_features)`` and ``1/sqrt(in_features)``.
    """

    def __init__(self, in_features, out_features):
        """Allocate the parameters and initialise them.

        Args:
            in_features: Size of the last dimension of the input.
            out_features: Size of the last dimension of the output.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Allocate uninitialised storage; reset_parameters() fills it.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        self.bias = nn.Parameter(torch.empty(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight and bias uniformly within +/- 1/sqrt(in_features)."""
        bound = 1.0 / math.sqrt(self.in_features)
        # Weight is sampled before bias so the RNG stream is consumed in a
        # fixed order for a given seed.
        for param in (self.weight, self.bias):
            nn.init.uniform_(param, -bound, bound)

    def forward(self, inp):
        """Return the affine map of ``inp`` computed by ``F.linear``."""
        return F.linear(inp, self.weight, self.bias)
import numpy as np import torch import torch.nn as nn import torch.optim as optim from collections import deque import random import time import sys from Models.Networks.SimpleNetwork import Network from utils.tool import read_config, read_seed np.random.seed(read_seed()) config = read_config("./Configs/configD2QN.yml") N_ACTIONS = config["n_actions"] GAMMA = config["gamma"] EPSILON_DECAY = config["epsilon_decay"] EPSILON_MIN = config["epsilon_min"] BATCH_SIZE = config["batch_size"] MIN_REPLAY_SIZE = config["min_replay_size"] TARGET_UPDATE = config["target_update"] SAMPLING_TIME = config["sampling_time"] NAME = config["model_name"] # Deep Q Network agent class Agent: def __init__(self):
parser = argparse.ArgumentParser( description='choose your model') parser.add_argument('-model', type=str, default='DQN', help='DQN/D2QN/D3QN/Noisy/PER/ApeX') parser.add_argument('-seed', type=int, default=-1, help='save/init') args = parser.parse_args() #seeds if args.seed == -1: config = read_config("./Configs/config" + args.model + ".yml") write_seed(config["seed"]) else: write_seed(args.seed) print("seed:", read_seed()) #Models if args.model == "ApeX": runApeX() elif args.model == "DQN": env = EnvCarla.CarEnv(0) agent = ModelDQN.Agent() TrainModel.run(config, agent, env) elif args.model == "D2QN": env = EnvCarla.CarEnv(0) agent = ModelD2QN.Agent() TrainModel.run(config, agent, env) elif args.model == "D3QN":