def maml_a2c_loss(train_episodes, learner, baseline, gamma, tau):
    """Return the A2C policy loss on magic-box-transformed cumulative log-probs.

    Args:
        train_episodes: Transition container exposing ``state()``,
            ``action()``, ``reward()``, ``done()`` and ``next_state()``.
        learner: Policy exposing ``log_prob(states, actions)``.
        baseline: Value baseline forwarded to ``compute_advantages``.
        gamma: Forwarded to ``compute_advantages`` (presumably the discount
            factor -- confirm against that helper).
        tau: Forwarded to ``compute_advantages`` (presumably the GAE
            parameter -- confirm against that helper).

    Returns:
        Whatever ``a2c.policy_loss`` returns for the transformed log-probs
        and the computed advantages.
    """
    # Update policy and baseline
    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()
    log_probs = learner.log_prob(states, actions)

    # Every step starts with weight 1; each step after the first then has the
    # previous step's done flag subtracted, so (assuming 0/1 flags) a step that
    # directly follows a terminal step gets weight 0.
    weights = th.ones_like(dones)
    # Keyword ``alpha`` form: the positional ``add_(alpha, other)`` overload is
    # deprecated in PyTorch.
    weights[1:].add_(dones[:-1], alpha=-1.0)
    # NOTE(review): divides by the count of done flags; this is inf/nan when
    # no step in the batch is marked done -- confirm callers guarantee one.
    weights /= dones.sum()

    cum_log_probs = weighted_cumsum(log_probs, weights)
    advantages = compute_advantages(baseline, tau, gamma, rewards, dones,
                                    states, next_states)
    return a2c.policy_loss(l2l.magic_box(cum_log_probs), advantages)
예제 #2
0
def vpg_a2c_loss(episodes, learner, baseline, gamma, tau, dice=False):
    """Compute the A2C policy loss for a batch of episodes.

    With ``dice=True`` the per-step log-probabilities are replaced by
    ``magic_box`` of their weighted cumulative sum (the DiCE objective)
    before being handed to ``a2c.policy_loss``; otherwise they are passed
    through unchanged.

    Args:
        episodes: Collected transitions, unpacked by ``get_episode_values``.
        learner: Policy exposing ``log_prob(states, actions)``.
        baseline: Value baseline forwarded to ``compute_advantages``.
        gamma: Forwarded to ``compute_advantages`` (presumably the discount
            factor -- confirm against that helper).
        tau: Forwarded to ``compute_advantages`` (presumably the GAE
            parameter -- confirm against that helper).
        dice: Whether to use the DiCE form of the objective.

    Returns:
        The value returned by ``a2c.policy_loss``.
    """
    # NOTE(review): the original comment says this helper also moves the
    # values to the device -- not visible from here.
    states, actions, rewards, dones, next_states = get_episode_values(episodes)

    # Log-probability of each taken action under the current policy.
    action_log_probs = learner.log_prob(states, actions)

    # Advantages are computed after the log-probs, same order as before;
    # per the original comment this step also fits the value function.
    advantages = compute_advantages(
        baseline, tau, gamma, rewards, dones, states, next_states
    )

    if not dice:
        return a2c.policy_loss(action_log_probs, advantages)

    # DiCE objective: weight 1 per step, minus the previous step's done flag,
    # then scaled by the total count of done flags.
    step_weights = torch.ones_like(dones)
    step_weights[1:].add_(dones[:-1], alpha=-1.0)
    step_weights /= dones.sum()
    dice_log_probs = magic_box(weighted_cumsum(action_log_probs, step_weights))

    return a2c.policy_loss(dice_log_probs, advantages)
예제 #3
0
def pg_loss(train_episodes, learner, baseline, discount):
    """Compute the policy-gradient loss on magic-box-transformed log-probs.

    Args:
        train_episodes: Transition container exposing ``state()``,
            ``action()``, ``reward()``, ``done()`` and ``next_state()``.
        learner: Policy exposing ``log_prob(states, actions)``.
        baseline: Value baseline forwarded to ``compute_advantages``.
        discount: Forwarded to ``compute_advantages`` (presumably the
            discount factor -- confirm against that helper).

    Returns:
        Whatever ``a2c.policy_loss`` returns for the transformed cumulative
        log-probs and the computed advantages.
    """
    states = train_episodes.state()
    actions = train_episodes.action()
    rewards = train_episodes.reward()
    dones = train_episodes.done()
    next_states = train_episodes.next_state()
    log_probs = learner.log_prob(states, actions)

    # Every step starts with weight 1; each step after the first then has the
    # previous step's done flag subtracted, so (assuming 0/1 flags) a step that
    # directly follows a terminal step gets weight 0.
    weights = torch.ones_like(dones)
    # Keyword ``alpha`` form: the positional ``add_(alpha, other)`` overload is
    # deprecated in PyTorch.
    weights[1:].add_(dones[:-1], alpha=-1.0)
    # NOTE(review): divides by the count of done flags; this is inf/nan when
    # no step in the batch is marked done -- confirm callers guarantee one.
    weights /= dones.sum()

    def weighted_cumulative_sum(values, weights):
        """In-place recurrence ``values[i] += values[i - 1] * weights[i]``."""
        # Start at 1: the first entry has no predecessor. Starting at 0 would
        # read ``values[-1]`` and leak the final entry into the first one.
        for i in range(1, values.size(0)):
            values[i] += values[i - 1] * weights[i]
        return values

    cum_log_probs = weighted_cumulative_sum(log_probs, weights)
    advantages = compute_advantages(baseline, discount, rewards, dones, states,
                                    next_states)
    return a2c.policy_loss(l2l.magic_box(cum_log_probs), advantages)