def fit(self, dataset):
    if not self._quiet:
        tqdm.write('Iteration ' + str(self._iter))

    state, action, reward, next_state, absorbing, last = parse_dataset(
        dataset)
    x = state.astype(np.float32)
    u = action.astype(np.float32)
    r = reward.astype(np.float32)
    xn = next_state.astype(np.float32)

    obs = to_float_tensor(x, self.policy.use_cuda)
    act = to_float_tensor(u, self.policy.use_cuda)

    # GAE targets for the critic and advantages for the policy update
    v_target, np_adv = compute_gae(self._V, x, xn, r, absorbing, last,
                                   self.mdp_info.gamma, self._lambda)
    # Normalize advantages to zero mean and unit variance
    np_adv = (np_adv - np.mean(np_adv)) / (np.std(np_adv) + 1e-8)
    adv = to_float_tensor(np_adv, self.policy.use_cuda)

    # Policy update
    self._old_policy = deepcopy(self.policy)
    old_pol_dist = self._old_policy.distribution_t(obs)
    old_log_prob = self._old_policy.log_prob_t(obs, act).detach()

    zero_grad(self.policy.parameters())
    loss = self._compute_loss(obs, act, adv, old_log_prob)

    prev_loss = loss.item()

    # Compute gradient
    loss.backward()
    g = get_gradient(self.policy.parameters())

    # Compute direction through conjugate gradient
    stepdir = self._conjugate_gradient(g, obs, old_pol_dist)

    # Line search
    self._line_search(obs, act, adv, old_log_prob, old_pol_dist,
                      prev_loss, stepdir)

    # VF update
    self._V.fit(x, v_target, **self._critic_fit_params)

    # Print fit information
    self._print_fit_info(dataset, x, v_target, old_pol_dist)
    self._iter += 1
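The targets and advantages above come from generalized advantage estimation (Schulman et al., 2016): with TD residuals delta_t = r_t + gamma * V(s_{t+1}) - V(s_t), the advantage is the discounted sum A_t = sum_l (gamma * lambda)^l * delta_{t+l}, truncated at trajectory boundaries. The subsequent normalization to zero mean and unit variance is a common variance-reduction trick. Below is a minimal NumPy sketch of what compute_gae produces; it takes precomputed value estimates instead of the critic object, and the exact boundary handling via absorbing and last is an assumption about the helper's semantics, not its actual implementation.

    import numpy as np

    def gae_sketch(v, v_next, r, absorbing, last, gamma, lam):
        # v, v_next: critic estimates V(s_t), V(s_{t+1}); r: rewards.
        # absorbing masks the bootstrap term in absorbing states; last
        # cuts the recursion at trajectory ends within the batch.
        delta = r + gamma * v_next * (1. - absorbing) - v
        adv = np.zeros_like(delta)
        gae = 0.
        for t in reversed(range(len(delta))):
            gae = delta[t] + gamma * lam * gae * (1. - last[t])
            adv[t] = gae
        # Advantages plus values give TD(lambda)-style critic targets
        return adv + v, adv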
def diff(self, *args, **kwargs):
    """
    Compute the derivative of the network output w.r.t. the network
    parameters, evaluated at the provided ``state``, and ``action``
    if provided.

    Args:
        state (np.ndarray): the state;
        action (np.ndarray, None): the action.

    Returns:
        The gradient of each output w.r.t. the network parameters, as
        an array of shape (n_parameters, n_outputs).

    """
    if not self._use_cuda:
        torch_args = [torch.from_numpy(np.atleast_2d(x)) for x in args]
    else:
        torch_args = [torch.from_numpy(np.atleast_2d(x)).cuda()
                      for x in args]

    y_hat = self.network(*torch_args, **kwargs)
    n_outs = 1 if len(y_hat.shape) == 0 else y_hat.shape[-1]
    y_hat = y_hat.view(-1, n_outs)

    gradients = list()
    for i in range(y_hat.shape[1]):
        # Backpropagate each output separately, clearing parameter
        # gradients in between; the graph is retained for the next pass
        zero_grad(self.network.parameters())
        y_hat[:, i].backward(retain_graph=True)

        gradient = list()
        for p in self.network.parameters():
            g = p.grad.detach().cpu().numpy()
            gradient.append(g.flatten())

        g = np.concatenate(gradient, 0)

        gradients.append(g)

    g = np.stack(gradients, -1)

    return g
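The loop backpropagates each output separately and concatenates the flattened parameter gradients, so the result has one row per parameter and one column per output. Below is a self-contained rendering of the same idea; the toy nn.Linear is a hypothetical stand-in for self.network, not the class's actual model.

    import numpy as np
    import torch
    import torch.nn as nn

    net = nn.Linear(3, 2)                    # stands in for self.network
    state = np.zeros(3, dtype=np.float32)

    x = torch.from_numpy(np.atleast_2d(state))
    y_hat = net(x).view(-1, 2)

    gradients = []
    for i in range(y_hat.shape[1]):
        net.zero_grad()                      # clear grads between outputs
        y_hat[:, i].backward(retain_graph=True)
        gradients.append(np.concatenate(
            [p.grad.detach().cpu().numpy().flatten()
             for p in net.parameters()]))
    g = np.stack(gradients, -1)

    # One row per parameter, one column per output:
    # Linear(3, 2) has 3 * 2 weights + 2 biases = 8 parameters.
    assert g.shape == (8, 2)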