def _shallow_reinforce_estimator(batch, disc, policy, baselinekind='peters',
                                 result='mean'):
    """REINFORCE estimator for shallow policies, using precomputed score
    functions instead of autograd."""
    with torch.no_grad():
        states, actions, rewards, mask, _ = unpack(batch) #NxHxm, NxHxd, NxH, NxH
        scores = policy.score(states, actions) #NxHxm
        scores = tensormat(scores, mask) #NxHxm
        G = torch.sum(scores, 1) #Nxm
        disc_rewards = discount(rewards, disc) #NxH
        rets = torch.sum(disc_rewards, 1) #N

        if baselinekind == 'avg':
            baseline = torch.mean(rets, 0) #1
        elif baselinekind == 'peters':
            baseline = torch.mean(G ** 2 * rets.unsqueeze(1), 0) / \
                       torch.mean(G ** 2, 0) #m
        else:
            baseline = torch.zeros(1) #1
        baseline[baseline != baseline] = 0 #replace NaNs from 0/0 with 0
        values = rets.unsqueeze(1) - baseline.unsqueeze(0) #Nxm

        _samples = G * values #Nxm
        if result == 'samples':
            return _samples #Nxm
        else:
            return torch.mean(_samples, 0) #m


def reinforce_estimator(batch, disc, policy, baselinekind='avg', result='mean',
                        shallow=False):
    """REINFORCE policy gradient estimator

    batch: list of N trajectories. Each trajectory is a tuple
        (states, actions, rewards, mask). Each element of the tuple is a
        tensor where the first dimension is time
    disc: discount factor
    policy: the one used to collect the data
    baselinekind: kind of baseline to employ in the estimator.
        Either 'avg' (average reward, default), 'peters' (variance-minimizing),
        or 'zero' (no baseline)
    result: whether to return the final estimate ('mean', default), or the
        single per-trajectory estimates ('samples')
    shallow: whether to use precomputed score functions (only available for
        shallow policies)
    """
    if shallow:
        return _shallow_reinforce_estimator(batch, disc, policy, baselinekind,
                                            result)

    N = len(batch)
    states, actions, rewards, mask, _ = unpack(batch) #NxHxm, NxHxd, NxH, NxH
    disc_rewards = discount(rewards, disc) #NxH
    rets = torch.sum(disc_rewards, 1) #N
    logps = policy.log_pdf(states, actions) * mask #NxH

    if baselinekind == 'peters':
        logp_sums = torch.sum(logps, 1) #N
        jac = jacobian(policy, logp_sums) #Nxm
        b_num = torch.sum(jac ** 2 * rets.unsqueeze(1), 0) #m
        b_den = torch.sum(jac ** 2, 0) #m
        baseline = b_num / b_den #m
        baseline[baseline != baseline] = 0 #replace NaNs from 0/0 with 0
        values = rets.unsqueeze(1) - baseline.unsqueeze(0) #Nxm
        _samples = jac * values #Nxm
    else:
        if baselinekind == 'avg':
            baseline = torch.mean(rets, 0) #1
        else:
            baseline = torch.zeros(1) #1
        baseline[baseline != baseline] = 0 #replace NaNs from 0/0 with 0
        values = rets - baseline #N
        if result == 'mean':
            logp_sums = torch.sum(logps, 1)
            return tu.flat_gradients(policy, logp_sums, values) / N
        #Per-trajectory estimates: gradient of the trajectory's total log-prob
        #times its (baselined) return
        _samples = torch.stack([tu.flat_gradients(policy, torch.sum(logps[i, :]))
                                * values[i]
                                for i in range(N)], 0) #Nxm

    if result == 'samples':
        return _samples #Nxm
    else:
        return torch.mean(_samples, 0) #m


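# For reference, the estimate computed above is, per trajectory n,
#
#     g_n = (sum_t grad log pi(a_t | s_t)) * (sum_t disc^t * r_t - b),
#
# averaged over the N trajectories when result='mean'. With baselinekind='avg'
# the baseline b is the mean return over the batch; with 'peters' it is the
# component-wise variance-minimizing baseline
#
#     b_j = E[score_j^2 * return] / E[score_j^2],
#
# estimated from the same batch (NaNs from zero denominators are set to 0).
#
# A minimal usage sketch (illustrative; assumes `policy` exposes the interface
# used above and `batch` is in the format described in the docstring):
#
#     grad = reinforce_estimator(batch, disc=0.99, policy=policy,
#                                baselinekind='peters', result='mean')  # m
#     per_traj = reinforce_estimator(batch, disc=0.99, policy=policy,
#                                    result='samples')                  # Nxm

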
def gpomdp_estimator(batch, disc, policy, baselinekind='avg', result='mean',
                     shallow=False):
    """G(PO)MDP policy gradient estimator

    batch: list of N trajectories. Each trajectory is a tuple
        (states, actions, rewards, mask). Each element of the tuple is a
        tensor where the first dimension is time
    disc: discount factor
    policy: the one used to collect the data
    baselinekind: kind of baseline to employ in the estimator.
        Either 'avg' (average reward, default), 'peters' (variance-minimizing),
        or 'zero' (no baseline)
    result: whether to return the final estimate ('mean', default), or the
        single per-trajectory estimates ('samples')
    shallow: whether to use precomputed score functions (only available for
        shallow policies)
    """
    if shallow:
        return _shallow_gpomdp_estimator(batch, disc, policy, baselinekind,
                                         result)

    N = len(batch)
    states, actions, rewards, mask, _ = unpack(batch) #NxHxd_s, NxHxd_a, NxH, NxH
    H = rewards.shape[1]
    m = policy.num_params()
    disc_rewards = discount(rewards, disc) #NxH
    logps = policy.log_pdf(states, actions) * mask #NxH
    cum_logps = torch.cumsum(logps, 1) #NxH

    if baselinekind == 'peters':
        jac = jacobian(policy, cum_logps.view(-1)).reshape((N, H, m)) #NxHxm
        b_num = torch.sum(tensormat(jac ** 2, disc_rewards), 0) #Hxm
        b_den = torch.sum(jac ** 2, 0) #Hxm
        baseline = b_num / b_den #Hxm
        baseline[baseline != baseline] = 0
        values = disc_rewards.unsqueeze(2) - baseline.unsqueeze(0) #NxHxm
        _samples = torch.sum(tensormat(values * jac, mask), 1) #Nxm
    else:
        if baselinekind == 'avg':
            baseline = torch.mean(disc_rewards, 0) #H
        else:
            baseline = torch.zeros(1) #1
        values = (disc_rewards - baseline) * mask #NxH
        _samples = torch.stack([tu.flat_gradients(policy, cum_logps[i, :],
                                                  values[i, :])
                                for i in range(N)], 0) #Nxm

    if result == 'samples':
        return _samples #Nxm
    else:
        return torch.mean(_samples, 0) #m


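# For reference, the G(PO)MDP estimate computed above is, per trajectory n,
#
#     g_n = sum_t (sum_{t'<=t} grad log pi(a_t' | s_t')) * (disc^t * r_t - b_t),
#
# averaged over the N trajectories when result='mean'. With baselinekind='avg'
# the baseline b_t is the per-step average discounted reward over the batch;
# with 'peters' it is the time- and component-wise variance-minimizing baseline
#
#     b_{t,j} = E[G_{t,j}^2 * disc^t * r_t] / E[G_{t,j}^2],
#
# where G_{t,j} is the j-th component of the cumulative score up to step t.
# Usage mirrors reinforce_estimator, e.g. (illustrative):
#
#     grad = gpomdp_estimator(batch, disc=0.99, policy=policy,
#                             baselinekind='peters', shallow=True)  # m

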
def _shallow_gpomdp_estimator(batch, disc, policy, baselinekind='peters',
                              result='mean'):
    """G(PO)MDP estimator for shallow policies, using precomputed score
    functions instead of autograd."""
    with torch.no_grad():
        states, actions, rewards, mask, _ = unpack(batch) #NxHxm, NxHxd, NxH, NxH
        disc_rewards = discount(rewards, disc) #NxH
        scores = policy.score(states, actions) #NxHxm
        G = torch.cumsum(tensormat(scores, mask), 1) #NxHxm

        if baselinekind == 'avg':
            baseline = torch.mean(disc_rewards, 0).unsqueeze(1) #Hx1
        elif baselinekind == 'peters':
            baseline = torch.sum(tensormat(G ** 2, disc_rewards), 0) / \
                       torch.sum(G ** 2, 0) #Hxm
        else:
            baseline = torch.zeros(1, 1) #1x1
        baseline[baseline != baseline] = 0 #replace NaNs from 0/0 with 0
        values = disc_rewards.unsqueeze(2) - baseline.unsqueeze(0) #NxHxm

        _samples = torch.sum(tensormat(G * values, mask), 1) #Nxm
        if result == 'samples':
            return _samples #Nxm
        else:
            return torch.mean(_samples, 0) #m


def metagrad(batch, disc, policy, alpha, result='mean', grad_samples=None):
    """Meta-gradient estimator for the log-scale parameter omega of a shallow
    policy (the first policy parameter, as in the Nx(m+1) layout of the
    G(PO)MDP samples), given a step of size alpha on the mean parameters.

    batch: list of N trajectories (same format as the estimators above)
    disc: discount factor
    policy: the (shallow) policy used to collect the data
    alpha: step size of the corresponding update on the mean parameters
    result: whether to return the final estimate ('mean', default), or the
        single per-trajectory estimates ('samples')
    grad_samples: optional precomputed per-trajectory G(PO)MDP estimates
        (Nx(m+1)); computed internally if None
    """
    sigma = torch.exp(policy.get_scale_params())
    if grad_samples is None:
        grad_samples = gpomdp_estimator(batch, disc, policy,
                                        baselinekind='peters',
                                        shallow=True,
                                        result='samples') #Nx(m+1)
    with torch.no_grad():
        upsilon_grad = grad_samples[:, 1:] #Nxm
        omega_grad = grad_samples[:, 0] #N

        states, actions, rewards, mask, _ = unpack(batch)
        disc_rewards = discount(rewards, disc)
        mix = mix_estimator(states, actions, disc_rewards, mask, policy,
                            result='samples') #Nxm
        mixed_der = mix - 2 * upsilon_grad #Nxm

        grad_norm = torch.sqrt(
            torch.bmm(upsilon_grad.unsqueeze(1),
                      upsilon_grad.unsqueeze(2)).view(-1)) #N
        norm_grad = torch.bmm(upsilon_grad.unsqueeze(1),
                              mixed_der.unsqueeze(2)).view(-1) / grad_norm #N

        A = omega_grad #N
        B = 2 * alpha * sigma ** 2 * grad_norm #N
        C = alpha * sigma ** 2 * norm_grad #N
        samples = A + B + C #N

        if result == 'samples':
            return samples
        else:
            return torch.mean(samples, 0)
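

# Illustrative usage of metagrad (a sketch; it assumes `policy` is a shallow
# Gaussian-like policy whose first parameter is the log-scale omega, consistent
# with the Nx(m+1) layout of the G(PO)MDP samples used above, and that alpha
# is the step size applied to the mean parameters):
#
#     omega_metagrad = metagrad(batch, disc=0.99, policy=policy, alpha=1e-2)
#
#     # Reuse per-trajectory gradient samples already computed elsewhere:
#     samples = gpomdp_estimator(batch, 0.99, policy, baselinekind='peters',
#                                shallow=True, result='samples')  # Nx(m+1)
#     omega_metagrad = metagrad(batch, 0.99, policy, alpha=1e-2,
#                               grad_samples=samples)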