コード例 #1
0
ファイル: main.py プロジェクト: BenediktKersjes/rl_2048
    # Training configuration. Every value below is forwarded to
    # WorkerArguments; SEED also seeds torch's RNG, LR is also the optimizer
    # learning rate, and USE_BIG_INPUT also selects the DQN constructor
    # argument.
    # NOTE(review): the meaning of the individual values (discount factor,
    # epsilon schedule, gradient clipping, ...) is inferred from their names
    # only -- confirm against WorkerArguments and the worker code.
    SEED = 1337
    LR = 0.0001
    START_STEP = 0
    GAMMA = 0.9
    EPS_START = 0.9
    EPS_DECAY = 200000
    NUM_STEPS = 20
    MAX_GRAD_NORM = 50
    USE_BIG_INPUT = False

    # Bundle the settings into a single object.
    args = WorkerArguments(SEED, LR, START_STEP, GAMMA, EPS_START, EPS_DECAY, NUM_STEPS, MAX_GRAD_NORM, USE_BIG_INPUT)

    torch.manual_seed(SEED)

    # Model whose parameters are placed in shared memory (presumably so that
    # worker processes can update it -- the workers are outside this excerpt).
    # The DQN argument is 16 or 1 depending on USE_BIG_INPUT; presumably the
    # number of input channels -- TODO confirm against DQN.
    shared_model = DQN(16 if USE_BIG_INPUT else 1).float()
    shared_model.share_memory()
    # LOAD_FILE is defined outside this excerpt; a non-empty value names a
    # checkpoint under ./trained_models/ whose 'model' entry is restored here.
    if LOAD_FILE != '':
        shared_model.load_state_dict(torch.load('./trained_models/' + LOAD_FILE + '.pth')['model'])

    # Second network with the same architecture, initialised as an exact copy
    # of the shared model's weights.
    target_model = DQN(16 if USE_BIG_INPUT else 1).float()
    target_model.share_memory()
    target_model.load_state_dict(shared_model.state_dict())

    # Optimizer over the shared model's parameters, also placed in shared
    # memory.
    optimizer = SharedAdam(shared_model.parameters(), lr=LR)
    optimizer.share_memory()
    # NOTE(review): the checkpoint file is read from disk a second time here
    # to fetch its 'optimizer' entry.
    if LOAD_FILE != '':
        optimizer.load_state_dict(torch.load('./trained_models/' + LOAD_FILE + '.pth')['optimizer'])

    # Starts empty; filled outside this excerpt.
    processes = []

    with mp.Manager() as manager:
コード例 #2
0
class Agent():
    """Distributional (categorical) double-DQN agent.

    The agent owns an online network and a target network, both built from
    ``DQN(args, action_space)``. Each network is expected to return, for a
    batch of states, a tensor indexed as ``(batch, action, atom)`` holding a
    probability distribution over ``args.atoms`` fixed return values (the
    "support") for every action. Both networks must expose ``reset_noise()``.

    Attributes:
        action_space: value returned by ``env.action_space()``; used as the
            upper bound for ``random.randrange`` in ``act_e_greedy``.
        atoms: number of atoms in the support.
        Vmin, Vmax: lowest / highest value of the support.
        support: tensor of ``atoms`` evenly spaced values in [Vmin, Vmax].
        delta_z: spacing between two neighbouring atoms.
        batch_size: number of transitions requested from the replay memory.
        n: number of steps of the multi-step return.
        discount: per-step discount factor.
        online_net, target_net: the two networks.
        optimiser: Adam optimiser over the online network's parameters.
    """

    def __init__(self, args, env):
        """Build networks, optimiser and support from ``args`` and ``env``.

        Args:
            args: settings object providing ``atoms``, ``V_min``, ``V_max``,
                ``batch_size``, ``multi_step``, ``discount``, ``model``,
                ``lr``, ``adam_eps`` and ``cuda``.
            env: environment object providing ``action_space()``.
        """
        self.action_space = env.action_space()
        self.atoms = args.atoms
        self.Vmin = args.V_min
        self.Vmax = args.V_max
        self.support = torch.linspace(args.V_min, args.V_max,
                                      args.atoms)  # Support (range) of z
        # float() keeps the spacing exact even if V_min/V_max are integers
        # and the interpreter uses truncating division.
        self.delta_z = float(args.V_max - args.V_min) / (args.atoms - 1)
        self.batch_size = args.batch_size
        self.n = args.multi_step
        self.discount = args.discount

        self.online_net = DQN(args, self.action_space)
        if args.model and os.path.isfile(args.model):
            # Map every storage to CPU while loading so that a checkpoint
            # written on a GPU machine also loads on a CPU-only one; the
            # network is moved to the GPU below when args.cuda is set.
            self.online_net.load_state_dict(
                torch.load(args.model,
                           map_location=lambda storage, loc: storage))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space)
        self.update_target_net()
        # The target network is kept in train mode as well (presumably so
        # that its noisy layers keep sampling noise -- confirm against DQN).
        self.target_net.train()
        for param in self.target_net.parameters():
            param.requires_grad = False  # Never optimised directly

        # Move the networks to the GPU *before* constructing the optimiser,
        # as recommended by the torch.optim documentation.
        if args.cuda:
            self.online_net.cuda()
            self.target_net.cuda()
            self.support = self.support.cuda()

        self.optimiser = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        """Resample the noise of the online network."""
        self.online_net.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        """Return the index of the action with the highest expected value.

        Args:
            state: a single, unbatched state; a batch dimension is added
                before it is passed to the online network.
        """
        # Expected value per action = sum over atoms of (probability * z)
        return (self.online_net(state.unsqueeze(0)).data *
                self.support).sum(2).max(1)[1][0]

    # Acts with an ε-greedy policy
    def act_e_greedy(self, state, epsilon=0.001):
        """Return a random action with probability ``epsilon``, else act()."""
        return random.randrange(
            self.action_space) if random.random() < epsilon else self.act(
                state)

    def learn(self, mem):
        """Run one optimisation step on a batch sampled from ``mem``.

        Args:
            mem: replay memory providing ``sample(batch_size)``, which
                returns ``(idxs, states, actions, returns, next_states,
                nonterminals, weights)``, and ``update_priorities(idxs,
                priorities)``. ``nonterminals`` is used as a
                ``(batch, 1)`` mask that is 0 for terminal transitions.

        The per-sample cross-entropy loss is reported back to ``mem`` as the
        new priority of each sampled transition.
        """
        # Sample transitions
        idxs, states, actions, returns, next_states, nonterminals, weights = mem.sample(
            self.batch_size)

        # Calculate current state probabilities
        self.online_net.reset_noise()  # Sample new noise for online network
        ps = self.online_net(states)  # Probabilities p(s_t, ·; θonline)
        ps_a = ps[range(self.batch_size), actions]  # p(s_t, a_t; θonline)

        # Calculate nth next state probabilities
        self.online_net.reset_noise()  # Sample new noise for action selection
        pns = self.online_net(
            next_states).data  # Probabilities p(s_t+n, ·; θonline)
        dns = self.support.expand_as(
            pns) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
        # Perform argmax action selection using online network:
        # argmax_a[(z, p(s_t+n, a; θonline))]
        argmax_indices_ns = dns.sum(2).max(1)[1]
        self.target_net.reset_noise()  # Sample new target net noise
        pns = self.target_net(
            next_states).data  # Probabilities p(s_t+n, ·; θtarget)
        # Double-Q probabilities
        # p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)
        pns_a = pns[range(self.batch_size), argmax_indices_ns]
        if nonterminals.min() == 0:
            # view(-1) (rather than squeeze()) keeps a proper 1-D index
            # vector for any batch size and any number of terminal rows.
            terminals = (1 - nonterminals.view(-1)).nonzero().view(-1)
            # Divide probability equally (explicit float division)
            pns_a[terminals] = 1.0 / self.atoms

        # Compute Tz (Bellman operator T applied to z)
        Tz = returns.unsqueeze(1) + nonterminals * (
            self.discount**self.n) * self.support.unsqueeze(
                0)  # Tz = R^n + (γ^n)z (accounting for terminal states)
        Tz = Tz.clamp(min=self.Vmin,
                      max=self.Vmax)  # Clamp between supported values
        # Compute L2 projection of Tz onto fixed support z
        b = (Tz - self.Vmin) / self.delta_z  # b = (Tz - Vmin) / Δz
        l, u = b.floor().long(), b.ceil().long()
        # When b falls exactly on an atom, l == u and both interpolation
        # weights (u - b) and (b - l) are zero, which would silently drop
        # that probability mass. Pull l one atom down (or, at the lowest
        # atom, push u one atom up) so the full mass lands on atom b.
        l[(u > 0) & (l == u)] -= 1
        u[(l < (self.atoms - 1)) & (l == u)] += 1

        # Distribute probability of Tz
        m = states.data.new(self.batch_size, self.atoms).zero_()
        # Row offsets into the flattened (batch * atoms) view of m, built
        # from integers directly instead of going through float linspace.
        offset = (torch.arange(0, self.batch_size).long() *
                  self.atoms).unsqueeze(1).expand(
                      self.batch_size, self.atoms).type_as(actions)
        m.view(-1).index_add_(
            0, (l + offset).view(-1),
            (pns_a *
             (u.float() - b)).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
        m.view(-1).index_add_(
            0, (u + offset).view(-1),
            (pns_a *
             (b - l.float())).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)

        ps_a = ps_a.clamp(min=1e-3)  # Clamp for numerical stability in log
        loss = -torch.sum(
            Variable(m) * ps_a.log(),
            1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
        self.online_net.zero_grad()
        (weights * loss).mean().backward()  # Importance weight losses
        self.optimiser.step()

        mem.update_priorities(
            idxs, loss.data)  # Update priorities of sampled transitions

    def update_target_net(self):
        """Copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    def save(self, path):
        """Write the online network's state dict to ``<path>/model.pth``."""
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        """Return the highest expected value over all actions for ``state``."""
        return (self.online_net(state.unsqueeze(0)).data *
                self.support).sum(2).max(1)[0][0]

    def train(self):
        """Put the online network into training mode."""
        self.online_net.train()

    def eval(self):
        """Put the online network into evaluation mode."""
        self.online_net.eval()

    def share_memory(self):
        """Move the online network's parameters into shared memory."""
        self.online_net.share_memory()