Example #1
import torch
# The helpers unpack, discount, tensormat, jacobian and the tu module (flat_gradients)
# used in these examples are assumed to be imported from the surrounding package.

def _shallow_reinforce_estimator(batch, disc, policy, baselinekind='peters', result='mean'):
    """REINFORCE estimator for shallow policies: uses the precomputed score
    functions (policy.score) instead of autograd, so no gradient tape is needed."""
    with torch.no_grad():
        states, actions, rewards, mask, _ = unpack(batch) #NxHxm, NxHxd, NxH, NxH
        
        scores = policy.score(states, actions) #NxHxm
        scores = tensormat(scores, mask) #NxHxm
        G = torch.sum(scores, 1) #Nxm
        disc_rewards = discount(rewards, disc) #NxH
        rets = torch.sum(disc_rewards, 1) #N
        
        if baselinekind == 'avg':
            baseline = torch.mean(rets, 0) #1
        elif baselinekind == 'peters':
            baseline = torch.mean(G ** 2 * rets.unsqueeze(1), 0) / \
                torch.mean(G ** 2, 0) #m
        else:
            baseline = torch.zeros(1) #1
        baseline[baseline != baseline] = 0 #replace NaN (from 0/0) with 0
        values = rets.unsqueeze(1) - baseline.unsqueeze(0) #Nxm
        
        _samples = G * values #Nxm
        if result == 'samples':
            return _samples #Nxm
        else:
            return torch.mean(_samples, 0) #m
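
A minimal self-contained sketch of the same computation on synthetic data, assuming a toy one-dimensional linear-Gaussian policy; plain torch operations stand in for the module's unpack, discount and tensormat helpers, and every name below is illustrative rather than part of the package:

import torch

N, H = 8, 10                       # trajectories, horizon
theta, sigma, gamma = 0.5, 1.0, 0.99
states = torch.randn(N, H, 1)
actions = theta * states + sigma * torch.randn(N, H, 1)
rewards = -actions.squeeze(-1) ** 2                                            #NxH
mask = torch.ones(N, H)

# score of a Gaussian policy with mean theta*s: d/dtheta log pi = (a - theta*s)*s / sigma^2
scores = ((actions - theta * states) * states / sigma ** 2).squeeze(-1)        #NxH (m=1)
G = torch.sum(scores * mask, 1, keepdim=True)                                  #Nx1
rets = torch.sum(rewards * gamma ** torch.arange(H, dtype=torch.float32), 1)   #N

baseline = torch.mean(G ** 2 * rets.unsqueeze(1), 0) / torch.mean(G ** 2, 0)   #Peters baseline
grad_estimate = torch.mean(G * (rets.unsqueeze(1) - baseline), 0)              #m=1
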
Example #2
def reinforce_estimator(batch, disc, policy, baselinekind='avg', 
                        result='mean', shallow=False):
    """REINFORCE policy gradient estimator
       
    batch: list of N trajectories. Each trajectory is a tuple 
        (states, actions, rewards, mask). Each element of the tuple is a 
        tensor where the first dimension is time.
    disc: discount factor
    policy: the one used to collect the data
    baselinekind: kind of baseline to employ in the estimator.
        Either 'avg' (average return, default), 'peters'
        (variance-minimizing), or 'zero' (no baseline)
    result: whether to return the final estimate ('mean', default) or the
        individual per-trajectory estimates ('samples')
    shallow: whether to use precomputed score functions (only available
        for shallow policies)
    """
    if shallow:
        return _shallow_reinforce_estimator(batch, disc, policy, baselinekind, result)
    
    N = len(batch)
    states, actions, rewards, mask, _ = unpack(batch) #NxHxm, NxHxd, NxH, NxH
    
    disc_rewards = discount(rewards, disc) #NxH
    rets = torch.sum(disc_rewards, 1) #N
    logps = policy.log_pdf(states, actions) * mask #NxH
    
    if baselinekind == 'peters':
        logp_sums = torch.sum(logps, 1) #N
        jac = jacobian(policy, logp_sums) #Nxm   
        b_num = torch.sum(jac ** 2 * rets.unsqueeze(1), 0) #m
        b_den = torch.sum(jac ** 2, 0) #m
        baseline = b_num / b_den #m
        baseline[baseline != baseline] = 0 #replace NaN (from 0/0) with 0
        values = rets.unsqueeze(1) - baseline.unsqueeze(0) #Nxm
        _samples = jac * values
    else:
        if baselinekind == 'avg':
            baseline = torch.mean(rets, 0) #1
        else:
            baseline = torch.zeros(1) #1
        baseline[baseline != baseline] = 0
        values = rets - baseline #N
        
        if result == 'mean':
            logp_sums = torch.sum(logps, 1)
            return tu.flat_gradients(policy, logp_sums, values) / N
        
        _samples = torch.stack([tu.flat_gradients(policy, logps[i, :]) * values[i]
                                for i in range(N)], 0) #Nxm
    if result == 'samples':
        return _samples #Nxm
    else:
        return torch.mean(_samples, 0) #m
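
For the non-shallow path, a hedged sketch of what the reward-weighted autograd gradient in the 'mean' branch amounts to, again for a toy one-parameter Gaussian policy; tu.flat_gradients is assumed to compute exactly this kind of flat parameter gradient, and all names below are illustrative:

import torch

N, H, gamma, sigma = 8, 10, 0.99, 1.0
theta = torch.tensor([0.3], requires_grad=True)
states = torch.randn(N, H)
actions = (0.3 * states + sigma * torch.randn(N, H)).detach()
rewards = -actions ** 2
rets = torch.sum(rewards * gamma ** torch.arange(H, dtype=torch.float32), 1)   #N

logps = -(actions - theta * states) ** 2 / (2 * sigma ** 2)                    #NxH (up to a constant)
logp_sums = torch.sum(logps, 1)                                                #N
values = rets - torch.mean(rets)                                               #'avg' baseline

# REINFORCE estimate: gradient of sum_i values[i] * logp_sums[i], divided by N
grad_estimate = torch.autograd.grad(torch.sum(values.detach() * logp_sums), theta)[0] / N
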
Example #3
def gpomdp_estimator(batch, disc, policy, baselinekind='avg', result='mean',
                     shallow=False):
    """G(PO)MDP policy gradient estimator
       
    batch: list of N trajectories. Each trajectory is a tuple 
        (states, actions, rewards, mask). Each element of the tuple is a 
        tensor where the first dimension is time.
    disc: discount factor
    policy: the one used to collect the data
    baselinekind: kind of baseline to employ in the estimator.
        Either 'avg' (per-timestep average reward, default), 'peters'
        (variance-minimizing), or 'zero' (no baseline)
    result: whether to return the final estimate ('mean', default) or the
        individual per-trajectory estimates ('samples')
    shallow: whether to use precomputed score functions (only available
        for shallow policies)
    """
    if shallow:
        return _shallow_gpomdp_estimator(batch, disc, policy, baselinekind, result)
    
    N = len(batch)
    states, actions, rewards, mask, _ = unpack(batch) #NxHxd_s, NxHxd_a, NxH, NxH
    H = rewards.shape[1]
    m = policy.num_params()
    
    disc_rewards = discount(rewards, disc) #NxH
    logps = policy.log_pdf(states, actions) * mask #NxH
    cum_logps = torch.cumsum(logps, 1) #NxH
    
    if baselinekind == 'peters':
        jac = jacobian(policy, cum_logps.view(-1)).reshape((N, H, m)) #NxHxm
        b_num = torch.sum(tensormat(jac ** 2, disc_rewards), 0) #Hxm
        b_den = torch.sum(jac ** 2, 0) #Hxm
        baseline = b_num / b_den #Hxm
        baseline[baseline != baseline] = 0 #replace NaN (from 0/0) with 0
        values = disc_rewards.unsqueeze(2) - baseline.unsqueeze(0) #NxHxm
        _samples = torch.sum(tensormat(values * jac, mask), 1) #Nxm
    else:
        if baselinekind == 'avg':
            baseline = torch.mean(disc_rewards, 0) #H
        else:
            baseline = torch.zeros(1) #1
        values = (disc_rewards - baseline) * mask #NxH
        
        _samples = torch.stack([tu.flat_gradients(policy, cum_logps[i, :], values[i, :])
                                for i in range(N)], 0) #Nxm
    if result == 'samples':
        return _samples #Nxm
    else:
        return torch.mean(_samples, 0) #m
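
A self-contained sketch of the G(PO)MDP estimate itself, where the cumulative scores ensure that the reward at time t is weighted only by the log-likelihood gradients of actions taken up to t (toy one-dimensional Gaussian policy again; all names are illustrative):

import torch

N, H, gamma, theta, sigma = 8, 10, 0.99, 0.5, 1.0
states = torch.randn(N, H)
actions = theta * states + sigma * torch.randn(N, H)
disc_rewards = -actions ** 2 * gamma ** torch.arange(H, dtype=torch.float32)   #NxH

scores = (actions - theta * states) * states / sigma ** 2                      #NxH (m=1)
G = torch.cumsum(scores, 1)                                                    #NxH cumulative scores
baseline = torch.mean(disc_rewards, 0)                                         #H ('avg' baseline)
_samples = torch.sum(G * (disc_rewards - baseline), 1)                         #N per-trajectory estimates
gpomdp_estimate = torch.mean(_samples, 0)
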
Example #4
def _shallow_gpomdp_estimator(batch, disc, policy, baselinekind='peters', result='mean'):
    """G(PO)MDP estimator for shallow policies: uses the precomputed score
    functions (policy.score) instead of autograd, so no gradient tape is needed."""
    with torch.no_grad():
        states, actions, rewards, mask, _ = unpack(batch) # NxHxm, NxHxd, NxH, NxH
        
        disc_rewards = discount(rewards, disc) #NxH
        scores = policy.score(states, actions) #NxHxm
        G = torch.cumsum(tensormat(scores, mask), 1) #NxHxm
        
        if baselinekind == 'avg':
            baseline = torch.mean(disc_rewards, 0).unsqueeze(1) #Hx1
        elif baselinekind == 'peters':
            baseline = torch.sum(tensormat(G ** 2, disc_rewards), 0) / \
                            torch.sum(G ** 2, 0) #Hxm
        else:
            baseline = torch.zeros(1,1) #1x1
        baseline[baseline != baseline] = 0 #replace NaN (from 0/0) with 0
        values = disc_rewards.unsqueeze(2) - baseline.unsqueeze(0) #NxHxm
        
        _samples = torch.sum(tensormat(G * values, mask), 1) #Nxm
        if result == 'samples':
            return _samples #Nxm
        else:
            return torch.mean(_samples, 0) #m
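
A shape-only sketch of the per-timestep Peters baseline used above, on synthetic tensors; tensormat(A, B) is assumed to multiply an NxHxm tensor elementwise by an NxH one over its first two dimensions, which is written out here as a plain broadcast:

import torch

N, H, m = 8, 10, 3
G = torch.randn(N, H, m)          #cumulative scores
disc_rewards = torch.randn(N, H)  #discounted rewards
mask = torch.ones(N, H)

b_num = torch.sum(G ** 2 * disc_rewards.unsqueeze(2), 0)      #Hxm
b_den = torch.sum(G ** 2, 0)                                  #Hxm
baseline = b_num / b_den                                      #Hxm
baseline[baseline != baseline] = 0                            #0/0 -> 0
values = disc_rewards.unsqueeze(2) - baseline.unsqueeze(0)    #NxHxm
_samples = torch.sum(G * values * mask.unsqueeze(2), 1)       #Nxm per-trajectory estimates
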
Example #5
def metagrad(batch, disc, policy, alpha, result='mean', grad_samples=None):
    """Per-trajectory meta-gradient estimates for the scale (omega) parameter of a
    shallow Gaussian policy, built from per-trajectory G(PO)MDP samples plus
    alpha- and sigma-dependent correction terms.

    batch: list of N trajectories (same format as in gpomdp_estimator)
    disc: discount factor
    policy: the (shallow, Gaussian) policy used to collect the data
    alpha: step-size coefficient entering the correction terms
    result: whether to return the final estimate ('mean', default) or the
        individual per-trajectory estimates ('samples')
    grad_samples: optional precomputed per-trajectory G(PO)MDP estimates (Nx(m+1),
        omega component first), as returned by gpomdp_estimator with shallow=True
        and result='samples'
    """
    sigma = torch.exp(policy.get_scale_params())

    if grad_samples is None:
        grad_samples = gpomdp_estimator(batch,
                                        disc,
                                        policy,
                                        baselinekind='peters',
                                        shallow=True,
                                        result='samples')  #Nx(m+1)
    with torch.no_grad():
        upsilon_grad = grad_samples[:, 1:]  #Nxm
        omega_grad = grad_samples[:, 0]  #N

        states, actions, rewards, mask, _ = unpack(batch)
        disc_rewards = discount(rewards, disc)

        mix = mix_estimator(states,
                            actions,
                            disc_rewards,
                            mask,
                            policy,
                            result='samples')  #Nxm
        mixed_der = mix - 2 * upsilon_grad  #Nxm
        grad_norm = torch.sqrt(
            torch.bmm(upsilon_grad.unsqueeze(1),
                      upsilon_grad.unsqueeze(2)).view(-1))
        norm_grad = torch.bmm(upsilon_grad.unsqueeze(1),
                              mixed_der.unsqueeze(2)).view(-1) / grad_norm  #N
        A = omega_grad  #N
        B = 2 * alpha * sigma**2 * grad_norm  #N
        C = alpha * sigma**2 * norm_grad  #N
        samples = A + B + C  #N
        if result == 'samples':
            return samples
        else:
            return torch.mean(samples, 0)
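
A small sketch of the batched inner products above: torch.bmm on Nx1xm and Nxmx1 tensors simply computes per-trajectory dot products, so grad_norm and norm_grad reduce to the plain expressions in the assertions below (synthetic tensors, illustrative names):

import torch

N, m = 8, 3
upsilon_grad = torch.randn(N, m)
mixed_der = torch.randn(N, m)

grad_norm = torch.sqrt(torch.bmm(upsilon_grad.unsqueeze(1),
                                 upsilon_grad.unsqueeze(2)).view(-1))          #N, ||g_i||
norm_grad = torch.bmm(upsilon_grad.unsqueeze(1),
                      mixed_der.unsqueeze(2)).view(-1) / grad_norm             #N

assert torch.allclose(grad_norm, upsilon_grad.norm(dim=1))
assert torch.allclose(norm_grad, (upsilon_grad * mixed_der).sum(1) / grad_norm)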