def maml_a2c_loss(train_episodes, learner, baseline, gamma, tau):
    # Update policy and baseline
    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()
    log_probs = learner.log_prob(states, actions)
    # Per-step weights: reset at episode boundaries, scaled by the episode count
    weights = torch.ones_like(dones)
    weights[1:].add_(dones[:-1], alpha=-1.0)
    weights /= dones.sum()
    cum_log_probs = weighted_cumsum(log_probs, weights)
    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                    dones, states, next_states)
    return a2c.policy_loss(l2l.magic_box(cum_log_probs), advantages)
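
# `weighted_cumsum` is called above but not defined in this section. The sketch
# below is an assumed implementation, mirroring the inline helper in pg_loss
# further down: a running sum over dim 0 where each step's carry-over from the
# previous step is scaled by that step's weight, so a zero weight resets the
# accumulation at an episode boundary.
def weighted_cumsum(values, weights):
    # Accumulate in place along dim 0; the first element is left unchanged
    for i in range(1, values.size(0)):
        values[i] += values[i - 1] * weights[i]
    return values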
def vpg_a2c_loss(episodes, learner, baseline, gamma, tau, dice=False):
    # Get episode values (moved to the right device)
    states, actions, rewards, dones, next_states = get_episode_values(episodes)
    # Log-probabilities of the taken actions under the current policy
    log_probs = learner.log_prob(states, actions)
    # Fit the value function, then compute and normalize advantages
    advantages = compute_advantages(baseline, tau, gamma, rewards,
                                    dones, states, next_states)
    # Optionally use the DiCE objective
    if dice:
        # Per-step weights: reset at episode boundaries, scaled by the episode count
        weights = torch.ones_like(dones)
        weights[1:].add_(dones[:-1], alpha=-1.0)
        weights /= dones.sum()
        cum_log_probs = weighted_cumsum(log_probs, weights)
        log_probs = magic_box(cum_log_probs)
    return a2c.policy_loss(log_probs, advantages)
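
# `compute_advantages` is also undefined here. The sketch below is an assumed
# implementation matching the seven-argument call used by maml_a2c_loss and
# vpg_a2c_loss (pg_loss below calls a variant that takes only a discount). It
# assumes torch is imported as in the functions above, that the baseline is a
# fittable value function exposing fit(states, returns) and __call__(states),
# and computes GAE(gamma, tau) in plain PyTorch.
def compute_advantages(baseline, tau, gamma, rewards, dones, states, next_states):
    # Discounted Monte-Carlo returns, used to fit the baseline
    returns = torch.empty_like(rewards)
    running = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.size(0))):
        running = rewards[t] + gamma * (1.0 - dones[t]) * running
        returns[t] = running
    baseline.fit(states, returns)

    # Generalized Advantage Estimation from the fitted baseline
    values = baseline(states).detach()
    next_values = baseline(next_states).detach()
    advantages = torch.empty_like(rewards)
    gae = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.size(0))):
        delta = rewards[t] + gamma * (1.0 - dones[t]) * next_values[t] - values[t]
        gae = delta + gamma * tau * (1.0 - dones[t]) * gae
        advantages[t] = gae

    # Normalize advantages, as the comment in vpg_a2c_loss suggests
    return (advantages - advantages.mean()) / (advantages.std() + 1e-8)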
def pg_loss(train_episodes, learner, baseline, discount):
    # Computes the policy-gradient loss
    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()
    log_probs = learner.log_prob(states, actions)
    # Per-step weights: reset at episode boundaries, scaled by the episode count
    weights = torch.ones_like(dones)
    weights[1:].add_(dones[:-1], alpha=-1.0)
    weights /= dones.sum()

    def weighted_cumulative_sum(values, weights):
        # Weighted running sum along dim 0; the first element is left unchanged
        for i in range(1, values.size(0)):
            values[i] += values[i - 1] * weights[i]
        return values

    cum_log_probs = weighted_cumulative_sum(log_probs, weights)
    advantages = compute_advantages(baseline, discount, rewards,
                                    dones, states, next_states)
    return a2c.policy_loss(l2l.magic_box(cum_log_probs), advantages)
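
# Hedged usage sketch: one way these losses could plug into a MAML-style inner
# adaptation step with learn2learn. `maml` is assumed to be an
# l2l.algorithms.MAML-wrapped policy (clone() and adapt() are its methods);
# `collect_episodes` is a hypothetical callable standing in for whatever
# episode-collection utility the surrounding script uses.
def adapt_on_task(maml, baseline, collect_episodes, task, gamma, tau, adapt_steps=1):
    # Clone the meta-policy so inner-loop gradients flow back to maml's parameters
    learner = maml.clone()
    for _ in range(adapt_steps):
        train_episodes = collect_episodes(task, learner)
        loss = maml_a2c_loss(train_episodes, learner, baseline, gamma, tau)
        learner.adapt(loss)  # differentiable inner-loop update
    # The caller can evaluate the adapted learner on fresh episodes and
    # backpropagate an outer loss through the adapt() steps above.
    return learner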