def updateNetwork(self, samples):
        # organize the mini-batch so that we can request "columns" from the data
        # e.g. we can get all of the actions, or all of the states with a single call
        batch = getBatchColumns(samples)

        # compute Q(s, a) for each sample in mini-batch
        Qs, x = self.policy_net(batch.states)
        Qsa = Qs.gather(1, batch.actions).squeeze()

        # by default Q(s', a') = 0 unless the next states are non-terminal
        Qspap = torch.zeros(batch.size, device=device)

        # if we don't have any non-terminal next states, then no need to bootstrap
        if batch.nterm_sp.shape[0] > 0:
            Qsp, _ = self.target_net(batch.nterm_sp)

            # bootstrapping term is the max Q value for the next-state
            # only assign to indices where the next state is non-terminal
            Qspap[batch.nterm] = Qsp.max(1).values

        # compute the empirical MSBE for this mini-batch and let torch auto-diff to optimize
        # don't worry about detaching the bootstrapping term for semi-gradient Q-learning
        # the target network handles that
        target = batch.rewards + batch.gamma * Qspap.detach()
        td_loss = 0.5 * f.mse_loss(target, Qsa)

        # make sure we have no gradients left over from previous update
        self.optimizer.zero_grad()
        self.target_net.zero_grad()

        # compute the entire gradient of the network using only the td error
        td_loss.backward()

        # update the *policy network* using the combined gradients
        self.optimizer.step()
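This snippet relies on a `getBatchColumns` helper and a `device` that are not shown on this page. Purely as an illustration of the interface the code above expects, a minimal sketch might look like the following; the `(s, a, sp, r, gamma)` tuple layout and the convention that a terminal transition stores `None` as its next state are assumptions, not the original implementation.

import torch
from collections import namedtuple

device = torch.device('cpu')

# hypothetical container exposing the attributes used in the snippet above
Batch = namedtuple('Batch', ['states', 'actions', 'rewards', 'gamma', 'nterm', 'nterm_sp', 'size'])

def getBatchColumns(samples):
    # samples: list of (s, a, sp, r, gamma) tuples, with sp = None on termination (assumed layout)
    states = torch.stack([torch.as_tensor(s, dtype=torch.float32) for s, _, _, _, _ in samples]).to(device)
    actions = torch.tensor([[a] for _, a, _, _, _ in samples], dtype=torch.int64, device=device)
    rewards = torch.tensor([r for _, _, _, r, _ in samples], dtype=torch.float32, device=device)
    gamma = torch.tensor([g for _, _, _, _, g in samples], dtype=torch.float32, device=device)

    # boolean mask marking which transitions have a non-terminal next state
    nterm = torch.tensor([sp is not None for _, _, sp, _, _ in samples], dtype=torch.bool, device=device)

    # stack only the non-terminal next states; empty tensor if every transition terminated
    sp_list = [sp for _, _, sp, _, _ in samples if sp is not None]
    if len(sp_list) > 0:
        nterm_sp = torch.stack([torch.as_tensor(sp, dtype=torch.float32) for sp in sp_list]).to(device)
    else:
        nterm_sp = torch.empty((0,) + tuple(states.shape[1:]), device=device)

    return Batch(states, actions, rewards, gamma, nterm, nterm_sp, len(samples))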
Example #2
    def update(self, s, a, sp, r, gamma):
        # store the transition in the replay buffer and count this environment step
        self.buffer.add((s, a, sp, r, gamma))
        self.steps += 1

        # periodically refresh the target network with the online network's weights
        if self.steps % self.target_refresh == 0 and self.target_refresh > 1:
            cloneNetworkWeights(self.value_net, self.target_net)

        # only start learning once the buffer holds more than a full mini-batch
        if len(self.buffer) > self.batch_size + 1:
            samples, idcs = self.buffer.sample(self.batch_size)
            batch = getBatchColumns(samples)
            predictions = self.forward(batch)
            tde = self.updateNetwork(batch, predictions)

            # use the TD errors from this update to refresh the sampled transitions' priorities
            self.buffer.update_priorities(idcs, tde)
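`update` assumes a replay buffer exposing `add`, `sample`, `update_priorities`, and `__len__`; refreshing priorities from TD errors suggests prioritized experience replay. A uniform-sampling stand-in with the same interface (not the original buffer, with the prioritization left as a no-op) could be:

import numpy as np

class SimpleReplayBuffer:
    def __init__(self, max_size=10000, seed=0):
        self.max_size = max_size
        self.data = []
        self.pos = 0
        self.rng = np.random.default_rng(seed)

    def add(self, transition):
        # overwrite the oldest entry once the buffer is full
        if len(self.data) < self.max_size:
            self.data.append(transition)
        else:
            self.data[self.pos] = transition
        self.pos = (self.pos + 1) % self.max_size

    def sample(self, batch_size):
        # uniform sampling; a prioritized buffer would sample proportionally to stored priorities
        idcs = self.rng.integers(0, len(self.data), size=batch_size)
        return [self.data[i] for i in idcs], idcs

    def update_priorities(self, idcs, tde):
        # no-op here; a prioritized buffer would store |tde| (plus a small constant) per index
        pass

    def __len__(self):
        return len(self.data)

The `cloneNetworkWeights` call presumably copies the online network's parameters into the target network, for example via `target_net.load_state_dict(value_net.state_dict())`.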
Example #3
    def updateActionNet(self, samples, q_net, target_q_net, optimizer, storeList):
        batch = getBatchColumns(samples)
        Qs, x = q_net(batch.states)

        # each of these per-action networks has a single output head, so no gather is needed
        Qsa = Qs.squeeze()

        # by default Q(s', a') = 0 unless the next states are non-terminal
        Qspap = torch.zeros(batch.size, device=device)

        if batch.nterm_sp.shape[0] > 0:
            # query each per-action target network and stack their outputs column-wise,
            # giving one row of next-state action values per non-terminal next state
            Qsp_back, _ = self.back_target_q_net(batch.nterm_sp)
            Qsp_stay, _ = self.stay_target_q_net(batch.nterm_sp)
            Qsp_forward, _ = self.forward_target_q_net(batch.nterm_sp)

            Qsp = torch.hstack([Qsp_back, Qsp_stay, Qsp_forward])

            # bootstrapping term is the max Q value for the next-state
            # only assign to indices where the next state is non-terminal
            Qspap[batch.nterm] = Qsp.max(1).values

        # compute the empirical MSBE for this mini-batch and let torch auto-diff to optimize
        # don't worry about detaching the bootstrapping term for semi-gradient Q-learning
        # the target network handles that
        target = batch.rewards + batch.gamma * Qspap.detach()
        td_loss = 0.5 * f.mse_loss(target, Qsa)

        # make sure we have no gradients left over from previous update
        optimizer.zero_grad()
        target_q_net.zero_grad()
        self.back_target_q_net.zero_grad()
        self.stay_target_q_net.zero_grad()
        self.forward_target_q_net.zero_grad()

        # compute the entire gradient of the network using only the td error
        td_loss.backward()

        Qs_state_array, _ = q_net(self.state_array)
        Qsa_mean_states = torch.mean(Qs_state_array, 0)
        storeList.append(Qsa_mean_states[0].detach().numpy())

        # update the *policy network* using the combined gradients
        optimizer.step()
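For clarity, `torch.hstack` concatenates the three target networks' column vectors along dimension 1, so the row-wise max picks the best of the back/stay/forward action values per next state. A small self-contained check of that shape logic, using made-up tensors rather than the actual networks, would be:

import torch

# pretend outputs of three single-output target networks for a batch of 4 next states
Qsp_back = torch.tensor([[0.1], [0.4], [0.2], [0.0]])
Qsp_stay = torch.tensor([[0.3], [0.1], [0.5], [0.2]])
Qsp_forward = torch.tensor([[0.2], [0.6], [0.1], [0.3]])

Qsp = torch.hstack([Qsp_back, Qsp_stay, Qsp_forward])  # shape (4, 3)
bootstrap = Qsp.max(1).values                          # shape (4,), per-row best action value
print(Qsp.shape, bootstrap)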
Example #4
    def updateNetwork(self, samples):
        self.updates += 1
        # organize the mini-batch so that we can request "columns" from the data
        # e.g. we can get all of the actions, or all of the states with a single call
        batch = getBatchColumns(samples)

        # compute Q(s, a) for each sample in mini-batch
        Qs, x = self.policy_net(batch.states)
        Qsa = Qs.gather(1, batch.actions).squeeze()

        # by default Q(s', a') = 0 unless the next states are non-terminal
        Qspap = torch.zeros(len(samples), device=device)

        # if we don't have any non-terminal next states, then no need to bootstrap
        if batch.nterm_sp.shape[0] > 0:
            Qsp, _ = self.target_net(batch.nterm_sp)

            # bootstrapping term is the max Q value for the next-state
            # only assign to indices where the next state is non-terminal
            Qspap[batch.nterm] = Qsp.max(1).values


        # compute the empirical MSBE for this mini-batch and let torch auto-diff to optimize
        # don't worry about detaching the bootstrapping term for semi-gradient Q-learning
        # the target network handles that
        target = batch.rewards + batch.gamma * Qspap.detach()
        td_loss = 0.5 * f.mse_loss(target, Qsa)

        # compute E[\delta | x] ~= <h, x>
        with torch.no_grad():
            delta_hats = torch.matmul(x, self.h.t())
            delta_hat = delta_hats.gather(1, batch.actions)

        # the gradient correction term is gamma * <h, x> * \nabla_w Q(s', a')
        # to compute this gradient, we use pytorch auto-diff
        correction_loss = torch.mean(batch.gamma * delta_hat * Qspap)

        # make sure we have no gradients left over from previous update
        self.optimizer.zero_grad()
        self.target_net.zero_grad()

        # compute the entire gradient of the network using only the td error
        td_loss.backward()

        # if we have non-terminal states in the mini-batch
        # then compute the correction term using the gradient of the *target network*
        if batch.nterm_sp.shape[0] > 0:
            correction_loss.backward()

        # add the gradients of the target network for the correction term to the gradients for the td error
        for (policy_param, target_param) in zip(self.policy_net.parameters(), self.target_net.parameters()):
            policy_param.grad.add_(target_param.grad)

        # update the *policy network* using the combined gradients
        self.optimizer.step()

        # update the secondary weights using a *fixed* feature representation generated by the policy network
        with torch.no_grad():
            delta = target - Qsa
            dh = (delta - delta_hat) * x

            # compute the update for each action independently
            # assume that there is a separate `h` vector for each individual action
            for a in range(self.actions):
                mask = (batch.actions == a).squeeze(1)

                # if this action was never taken in this mini-batch
                # then skip the update for this action
                if mask.sum() == 0:
                    continue

                # the update for `h` minus the regularizer
                h_update = dh[mask].mean(0) - self.beta * self.h[a]

                # ADAM optimizer with bias correction
                # keep a separate set of weights for each action here as well
                self.v[a] = self.beta_2 * self.v[a] + (1 - self.beta_2) * (h_update**2)
                self.m[a] = self.beta_1 * self.m[a] + (1 - self.beta_1) * h_update

                m = self.m[a] / (1 - self.beta_1**self.updates)
                v = self.v[a] / (1 - self.beta_2**self.updates)

                self.h[a] = self.h[a] + self.alpha * m / (torch.sqrt(v) + self.eps)
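The tail of this example maintains the secondary weights `h` with a hand-rolled, bias-corrected ADAM step, keeping one pair of moment estimates per action. Pulled out of the class as a standalone sketch with the same hyperparameter names, the step amounts to roughly:

import torch

def adam_step(h, m, v, grad, alpha, beta_1, beta_2, eps, t):
    # exponential moving averages of the gradient and its elementwise square
    m = beta_1 * m + (1 - beta_1) * grad
    v = beta_2 * v + (1 - beta_2) * grad**2

    # bias correction for the early iterations; t is the running update count
    m_hat = m / (1 - beta_1**t)
    v_hat = v / (1 - beta_2**t)

    # apply the step with the same '+' sign used for h in the snippet above
    h = h + alpha * m_hat / (torch.sqrt(v_hat) + eps)
    return h, m, v

In the snippet, `grad` plays the role of `h_update` (the regularized mean of `dh` for the masked action) and `t` the role of `self.updates`.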