import torch
import torch as th  # both aliases appear in the snippets below
import torch.nn as nn
import cherry as ch
import learn2learn as l2l
from cherry.algorithms import a2c

# Module-level hyperparameters assumed by update(); the values below are
# placeholders, not from the original code.
GAMMA = 0.99
V_WEIGHT = 0.5
ENT_WEIGHT = 0.01
GRAD_NORM = 0.5


def update(replay, optimizer, policy, env):
    # Compute discounted returns, bootstrapping from the last next-state value
    _, next_state_value = policy(replay[-1].next_state)
    rewards = ch.discount(GAMMA,
                          replay.reward(),
                          replay.done(),
                          bootstrap=next_state_value)
    rewards = rewards.detach()

    # Compute the A2C loss: policy term + weighted value term - entropy bonus
    entropy = replay.entropy().mean()
    advantages = rewards - replay.value().detach()
    policy_loss = a2c.policy_loss(replay.log_prob(), advantages)
    value_loss = a2c.state_value_loss(replay.value(), rewards)
    loss = policy_loss + V_WEIGHT * value_loss - ENT_WEIGHT * entropy

    # Take an optimization step with gradient clipping
    optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(policy.parameters(), GRAD_NORM)
    optimizer.step()

    env.log('policy loss', policy_loss.item())
    env.log('value loss', value_loss.item())
    env.log('entropy', entropy.item())
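# Hypothetical usage sketch for update() above: roll out the policy into a
# cherry ExperienceReplay, then apply one A2C step. The env wrapper stack
# (e.g. ch.envs.Torch plus ch.envs.Logger, which would provide env.log) and
# the (density, value) policy signature are assumptions, not from the
# original listing.
def collect_and_update(policy, optimizer, env, n_steps=200):
    replay = ch.ExperienceReplay()
    state = env.reset()
    for _ in range(n_steps):
        density, value = policy(state)
        action = density.sample()
        next_state, reward, done, _ = env.step(action)
        # Extra fields are stored on the transition and exposed later as
        # replay.log_prob(), replay.value(), replay.entropy()
        replay.append(state, action, reward, next_state, done,
                      log_prob=density.log_prob(action),
                      value=value,
                      entropy=density.entropy())
        state = env.reset() if done else next_state
    update(replay, optimizer, policy, env)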
def maml_a2c_loss(train_episodes, learner, baseline, gamma, tau):
    # Update policy and baseline
    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()
    log_probs = learner.log_prob(states, actions)
    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                    dones, states, next_states)
    advantages = ch.normalize(advantages).detach()
    return a2c.policy_loss(log_probs, advantages)
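# compute_advantages() is referenced throughout but not defined in this
# listing. Below is a minimal sketch consistent with the call sites,
# assuming a baseline with a .fit(states, returns) method (e.g. cherry's
# LinearValue) and cherry's GAE implementation. Note that the pg_loss
# variant further down calls a version without the tau argument.
def compute_advantages(baseline, tau, gamma, rewards, dones, states,
                       next_states, update_vf=True):
    # Optionally refit the baseline to the discounted returns
    returns = ch.td.discount(gamma, rewards, dones)
    if update_vf:
        baseline.fit(states, returns)
    values = baseline(states)
    next_values = baseline(next_states)
    # Bootstrap from the next state's value where an episode ends
    bootstraps = values * (1.0 - dones) + next_values * dones
    return ch.pg.generalized_advantage(tau=tau,
                                       gamma=gamma,
                                       rewards=rewards,
                                       dones=dones,
                                       values=bootstraps,
                                       next_value=th.zeros(1))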
# DiCE variant of maml_a2c_loss: accumulates log-probabilities per episode
# and wraps them in the magic-box operator.
def maml_a2c_loss(train_episodes, learner, baseline, gamma, tau):
    # Update policy and baseline
    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()
    log_probs = learner.log_prob(states, actions)
    # Zero the cumulative weight at episode boundaries
    weights = th.ones_like(dones)
    weights[1:].add_(dones[:-1], alpha=-1.0)
    weights /= dones.sum()
    cum_log_probs = weighted_cumsum(log_probs, weights)
    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                    dones, states, next_states)
    return a2c.policy_loss(l2l.magic_box(cum_log_probs), advantages)
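# weighted_cumsum() is likewise assumed to be defined elsewhere. A minimal
# sketch matching the inline weighted_cumulative_sum in pg_loss below: a
# running sum whose carry is scaled by each step's weight, so a zero weight
# restarts the sum at an episode boundary. Mutates `values` in place.
def weighted_cumsum(values, weights):
    for i in range(1, values.size(0)):
        values[i] += values[i - 1] * weights[i]
    return values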
def trpo_a2c_loss(episodes, learner, baseline, gamma, tau, update_vf=True):
    # Move episode tensors to the device
    states, actions, rewards, dones, next_states = get_episode_values(episodes)

    # Log-probabilities of the taken actions under the current policy
    log_probs = learner.log_prob(states, actions)

    # Compute advantages & normalize
    advantages = compute_advantages(baseline, tau, gamma, rewards, dones,
                                    states, next_states, update_vf=update_vf)
    advantages = ch.normalize(advantages).detach()

    # Compute the policy loss
    return a2c.policy_loss(log_probs, advantages)
def vpg_a2c_loss(episodes, learner, baseline, gamma, tau, dice=False):
    # Move episode tensors to the device
    states, actions, rewards, dones, next_states = get_episode_values(episodes)

    # Log-probabilities of the taken actions under the current policy
    log_probs = learner.log_prob(states, actions)

    # Fit value function and compute advantages
    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                    dones, states, next_states)

    # Optionally swap in the DiCE objective
    if dice:
        weights = torch.ones_like(dones)
        weights[1:].add_(dones[:-1], alpha=-1.0)
        weights /= dones.sum()
        cum_log_probs = weighted_cumsum(log_probs, weights)
        log_probs = magic_box(cum_log_probs)
    return a2c.policy_loss(log_probs, advantages)
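# magic_box() is DiCE's "magic box" operator (Foerster et al., 2018). The
# snippets here call both l2l.magic_box and a bare magic_box; if no library
# version is available, the operator is a one-liner. It evaluates to 1 in
# the forward pass but reinjects the gradient of its argument on backward.
def magic_box(x):
    return torch.exp(x - x.detach())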
def pg_loss(train_episodes, learner, baseline, discount):
    # Computes the (DiCE-style) policy-gradient loss
    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()
    log_probs = learner.log_prob(states, actions)

    # Zero the cumulative weight at episode boundaries
    weights = torch.ones_like(dones)
    weights[1:].add_(dones[:-1], alpha=-1.0)
    weights /= dones.sum()

    def weighted_cumulative_sum(values, weights):
        # Start at index 1 so the first element is left unscaled
        for i in range(1, values.size(0)):
            values[i] += values[i - 1] * weights[i]
        return values

    cum_log_probs = weighted_cumulative_sum(log_probs, weights)
    # NB: this variant assumes a compute_advantages without a tau parameter
    advantages = compute_advantages(baseline, discount, rewards,
                                    dones, states, next_states)
    return a2c.policy_loss(l2l.magic_box(cum_log_probs), advantages)
# Variant that computes advantages inline with cherry's td/pg utilities,
# for a policy that returns its action density in an info dict.
def maml_a2c_loss(train_episodes, learner, baseline, gamma, tau):
    rewards = train_episodes.reward()
    states = train_episodes.state()
    densities = learner(states)[1]['density']
    log_probs = densities.log_prob(train_episodes.action())
    log_probs = log_probs.mean(dim=1, keepdim=True)
    dones = train_episodes.done()

    # Update baseline
    returns = ch.td.discount(gamma, rewards, dones)
    baseline.fit(states, returns)
    values = baseline(states)

    # Update model
    advantages = ch.pg.generalized_advantage(tau=tau,
                                             gamma=gamma,
                                             rewards=rewards,
                                             dones=dones,
                                             values=values,
                                             next_value=th.zeros(1))
    return a2c.policy_loss(log_probs, advantages)
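# Hypothetical usage sketch of the inner MAML adaptation step, assuming
# learn2learn's MAML wrapper; `sample_episodes` is an illustrative stand-in
# for an episode collector, not part of the original listing.
maml = l2l.algorithms.MAML(policy, lr=0.1)  # policy: a torch.nn.Module

def adapt_sketch(task, baseline, gamma=0.99, tau=1.0, adapt_steps=1):
    learner = maml.clone()  # differentiable copy of the policy for this task
    for _ in range(adapt_steps):
        train_episodes = sample_episodes(task, learner)  # assumed collector
        loss = maml_a2c_loss(train_episodes, learner, baseline, gamma, tau)
        learner.adapt(loss)  # inner-loop gradient step through the loss
    return learner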