Example #1
def ddpg_step(policy_net, policy_net_target, value_net, value_net_target, optimizer_policy, optimizer_value,
              states, actions, rewards, next_states, masks, gamma, polyak):
    masks = masks.unsqueeze(-1)
    rewards = rewards.unsqueeze(-1)
    """update critic"""

    values = value_net(states, actions)

    with torch.no_grad():
        target_next_values = value_net_target(next_states, policy_net_target(next_states))
        target_values = rewards + gamma * masks * target_next_values
    value_loss = nn.MSELoss()(values, target_values)

    optimizer_value.zero_grad()
    value_loss.backward()
    optimizer_value.step()

    """update actor"""

    policy_loss = - value_net(states, policy_net(states)).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()

    """soft update target nets"""
    policy_net_flat_params = get_flat_params(policy_net)
    policy_net_target_flat_params = get_flat_params(policy_net_target)
    set_flat_params(policy_net_target, polyak * policy_net_target_flat_params + (1 - polyak) * policy_net_flat_params)

    value_net_flat_params = get_flat_params(value_net)
    value_net_target_flat_params = get_flat_params(value_net_target)
    set_flat_params(value_net_target, polyak * value_net_target_flat_params + (1 - polyak) * value_net_flat_params)

    return value_loss, policy_loss
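
The soft (Polyak) updates here and in the later examples go through get_flat_params / set_flat_params, which are not shown. A minimal sketch of how such helpers are commonly written (an assumption about this codebase, not its exact code):

import torch

def get_flat_params(model):
    # concatenate every parameter tensor into a single 1-D tensor
    return torch.cat([param.data.view(-1) for param in model.parameters()])

def set_flat_params(model, flat_params):
    # copy slices of a 1-D tensor back into the model's parameters, in order
    idx = 0
    for param in model.parameters():
        numel = param.numel()
        param.data.copy_(flat_params[idx:idx + numel].view_as(param))
        idx += numel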
Example #2
def line_search(model,
                f,
                x,
                step_dir,
                expected_improve,
                max_backtracks=10,
                accept_ratio=0.1):
    """
    max f(x) <=> min -f(x)
    line search sufficient condition: -f(x_new) <= -f(x) + -e coeff * step_dir
    perform line search method for choosing step size
    :param model:
    :param f:
    :param x:
    :param step_dir: direction to update model parameters
    :param expected_improve:
    :param max_backtracks:
    :param accept_ratio:
    :return:
    """
    f_val = f(False).item()

    for step_coefficient in [.5**k for k in range(max_backtracks)]:
        x_new = x + step_coefficient * step_dir
        set_flat_params(model, x_new)
        f_val_new = f(False).item()
        actual_improve = f_val_new - f_val
        improve = expected_improve * step_coefficient
        ratio = actual_improve / improve
        if ratio > accept_ratio:
            return True, x_new
    return False, x
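
The acceptance rule can be seen on made-up numbers (illustrative values only, not from the repository): the search keeps halving the step until the actual improvement reaches at least accept_ratio of the predicted improvement step_coefficient * expected_improve.

# toy numbers, purely illustrative
f_val = 2.0                  # f(x)
expected_improve = 1.0       # predicted improvement for the full step
accept_ratio = 0.1

# pretend objective values at x + step_coefficient * step_dir
candidate_f_vals = {1.0: 1.9, 0.5: 2.3, 0.25: 2.2}

for step_coefficient in [1.0, 0.5, 0.25]:
    actual_improve = candidate_f_vals[step_coefficient] - f_val
    ratio = actual_improve / (expected_improve * step_coefficient)
    if ratio > accept_ratio:
        print("accepted step coefficient:", step_coefficient)  # -> 0.5
        break
# the full step (1.0) overshoots (ratio = -0.1); the half step is accepted (ratio = 0.6)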
Example #3
def update_policy(policy_net: nn.Module, states, actions, old_log_probs,
                  advantages, max_kl, damping):
    def get_loss(grad=True):
        log_probs = policy_net.get_log_prob(states, actions)
        if not grad:
            log_probs = log_probs.detach()
        ratio = torch.exp(log_probs - old_log_probs)
        loss = (ratio * advantages).mean()
        return loss

    def Hvp(v):
        """
        compute vector product of second order derivative of KL_Divergence Hessian and v
        :param v: vector
        :return: \nabla \nabla H @ v
        """
        # compute kl divergence between current policy and old policy
        kl = policy_net.get_kl(states)
        kl = kl.mean()

        # first order gradient kl
        grads = torch.autograd.grad(kl,
                                    policy_net.parameters(),
                                    create_graph=True)
        flat_grad_kl = torch.cat([grad.view(-1) for grad in grads])

        kl_v = (flat_grad_kl * v).sum()  # flag_grad_kl.T @ v
        # second order gradient of kl
        grads = torch.autograd.grad(kl_v, policy_net.parameters())
        flat_grad_grad_kl = torch.cat(
            [grad.contiguous().view(-1) for grad in grads]).detach()

        return flat_grad_grad_kl + v * damping

    # gradient g of the surrogate loss w.r.t. the policy parameters
    loss = get_loss()
    loss_grads = torch.autograd.grad(loss, policy_net.parameters())
    loss_grad = torch.cat([grad.view(-1)
                           for grad in loss_grads]).detach()  # g

    # conjugate gradient solves H x = g using the Hessian-vector product `Hvp`,
    # so step_dir is an approximate H^(-1) g without ever forming H
    step_dir = conjugate_gradient(Hvp, loss_grad)
    # s^T H s (equivalently g^T H^(-1) g), computed as Hvp(step_dir) @ step_dir
    shs = Hvp(step_dir).t() @ step_dir
    # largest step length satisfying the KL constraint: sqrt(2 * max_kl / s^T H s)
    lm = torch.sqrt(2 * max_kl / shs)
    step = lm * step_dir  # full parameter update for the policy net
    expected_improve = loss_grad.t() @ step
    """
    line search for step size 
    """
    current_flat_parameters = get_flat_params(policy_net)  # theta
    success, new_flat_parameters = line_search(policy_net, get_loss,
                                               current_flat_parameters, step,
                                               expected_improve, 10)
    set_flat_params(policy_net, new_flat_parameters)
    # `success` indicates whether the line search found an acceptable step
    return success
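
conjugate_gradient is referenced above but not defined. A standard implementation sketch (assumed, not necessarily identical to the one used here) that solves H x = g from matrix-vector products alone:

import torch

def conjugate_gradient(Avp_func, b, n_iters=10, residual_tol=1e-10):
    # solve A x = b using only the matrix-vector product Avp_func(v) = A @ v
    x = torch.zeros_like(b)
    r = b.clone()            # residual b - A x (x starts at zero)
    p = r.clone()            # search direction
    rdotr = r @ r
    for _ in range(n_iters):
        Avp = Avp_func(p)
        alpha = rdotr / (p @ Avp)
        x += alpha * p
        r -= alpha * Avp
        new_rdotr = r @ r
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x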
Example #4
    def value_objective_grad_func(value_net_flat_params):
        """
        Gradient of the value objective, for scipy's optimizer.
        """
        set_flat_params(value_net, DOUBLE(value_net_flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        value_loss.backward()  # to get the grad
        objective_value_loss_grad = get_flat_grad_params(
            value_net).detach().cpu().numpy()
        return objective_value_loss_grad
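
get_flat_grad_params mirrors get_flat_params but collects gradients instead of parameter values. A plausible sketch (an assumption; parameters without a gradient contribute zeros):

import torch

def get_flat_grad_params(model):
    # concatenate the .grad of every parameter into a single 1-D tensor
    return torch.cat([
        param.grad.view(-1) if param.grad is not None
        else torch.zeros_like(param).view(-1)
        for param in model.parameters()
    ])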
Example #5
    def value_objective_func(value_net_flat_params):
        """
        get value_net loss
        :param value_net_flat_params: numpy
        :return:
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        objective_value_loss = value_loss.item()
        # print("Current value loss: ", objective_value_loss)
        return objective_value_loss
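
FLOAT (and DOUBLE in Example #4) convert the numpy arrays coming back from scipy into torch tensors of the matching dtype. A hypothetical minimal version (the real helpers may also move the tensor to a specific device):

import torch

def FLOAT(x):
    # wrap numpy data as a float32 torch tensor
    return torch.as_tensor(x, dtype=torch.float32)

def DOUBLE(x):
    # wrap numpy data as a float64 torch tensor
    return torch.as_tensor(x, dtype=torch.float64)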
Example #6
    def value_objective_grad_func(value_net_flat_params):
        """
        objective function for scipy optimizing 
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        value_loss.backward()  # to get the grad
        objective_value_loss_grad = get_flat_grad_params(
            value_net).detach().cpu().numpy().astype(np.float64)
        return objective_value_loss_grad
Example #7
def trpo_step(policy_net,
              value_net,
              states,
              actions,
              returns,
              advantages,
              old_log_probs,
              max_kl,
              damping,
              l2_reg,
              optimizer_value=None):
    """
    Update by TRPO algorithm
    """
    """update critic"""
    def value_objective_func(value_net_flat_params):
        """
        get value_net loss
        :param value_net_flat_params: numpy
        :return:
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        objective_value_loss = value_loss.item()
        # print("Current value loss: ", objective_value_loss)
        return objective_value_loss

    def value_objective_grad_func(value_net_flat_params):
        """
        objective function for scipy optimizing 
        """
        set_flat_params(value_net, FLOAT(value_net_flat_params))
        for param in value_net.parameters():
            if param.grad is not None:
                param.grad.data.fill_(0)
        values_pred = value_net(states)
        value_loss = nn.MSELoss()(values_pred, returns)
        # weight decay
        for param in value_net.parameters():
            value_loss += param.pow(2).sum() * l2_reg

        value_loss.backward()  # to get the grad
        objective_value_loss_grad = get_flat_grad_params(
            value_net).detach().cpu().numpy().astype(np.float64)
        return objective_value_loss_grad

    if optimizer_value is None:
        """ 
        update by scipy optimizing, for detail about L-BFGS-B: ref: 
        https://docs.scipy.org/doc/scipy/reference/optimize.minimize-lbfgsb.html#optimize-minimize-lbfgsb
        """
        value_net_flat_params_old = get_flat_params(
            value_net).detach().cpu().numpy().astype(
                np.float64)  # initial guess
        res = opt.minimize(value_objective_func,
                           value_net_flat_params_old,
                           method='L-BFGS-B',
                           jac=value_objective_grad_func,
                           options={
                               "maxiter": 30,
                               "disp": False
                           })
        # print("Call L-BFGS-B, result: ", res)
        value_net_flat_params_new = res.x
        set_flat_params(value_net, FLOAT(value_net_flat_params_new))

    else:
        """
        update by gradient descent
        """
        for _ in range(10):
            values_pred = value_net(states)
            value_loss = nn.MSELoss()(values_pred, returns)
            # weight decay
            for param in value_net.parameters():
                value_loss += param.pow(2).sum() * l2_reg
            optimizer_value.zero_grad()
            value_loss.backward()
            optimizer_value.step()
    """update policy"""
    update_policy(policy_net, states, actions, old_log_probs, advantages,
                  max_kl, damping)
Example #8
def sac_alpha_step(policy_net,
                   q_net_1,
                   q_net_2,
                   alpha,
                   q_net_target_1,
                   q_net_target_2,
                   optimizer_policy,
                   optimizer_q_net_1,
                   optimizer_q_net_2,
                   optimizer_a,
                   states,
                   actions,
                   rewards,
                   next_states,
                   masks,
                   gamma,
                   polyak,
                   target_entropy,
                   update_target=False):
    rewards = rewards.unsqueeze(-1)
    masks = masks.unsqueeze(-1)
    """update qvalue net"""
    with torch.no_grad():
        next_actions, next_log_probs = policy_net.rsample(next_states)
        target_q_value = torch.min(
            q_net_target_1(next_states, next_actions),
            q_net_target_2(next_states, next_actions)) - alpha * next_log_probs
        target_q_value = rewards + gamma * masks * target_q_value

    q_value_1 = q_net_1(states, actions)
    q_value_loss_1 = nn.MSELoss()(q_value_1, target_q_value)
    optimizer_q_net_1.zero_grad()
    q_value_loss_1.backward()
    optimizer_q_net_1.step()

    q_value_2 = q_net_2(states, actions)
    q_value_loss_2 = nn.MSELoss()(q_value_2, target_q_value)
    optimizer_q_net_2.zero_grad()
    q_value_loss_2.backward()
    optimizer_q_net_2.step()
    """update policy net"""
    new_actions, log_probs = policy_net.rsample(states)
    min_q = torch.min(q_net_1(states, new_actions),
                      q_net_2(states, new_actions))

    policy_loss = (alpha * log_probs - min_q).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()
    """update alpha"""
    alpha_loss = -alpha * (log_probs.detach() + target_entropy).mean()
    optimizer_a.zero_grad()
    alpha_loss.backward()
    optimizer_a.step()

    if update_target:
        """ 
        soft update target qvalue net 
        """
        q_net_1_flat_params = get_flat_params(q_net_1)
        q_net_target_1_flat_params = get_flat_params(q_net_target_1)

        set_flat_params(q_net_target_1, (1 - polyak) * q_net_1_flat_params +
                        polyak * q_net_target_1_flat_params)

        q_net_2_flat_params = get_flat_params(q_net_2)
        q_net_target_2_flat_params = get_flat_params(q_net_target_2)

        set_flat_params(q_net_target_2, (1 - polyak) * q_net_2_flat_params +
                        polyak * q_net_target_2_flat_params)

    return {
        "q_value_loss_1": q_value_loss_1,
        "q_value_loss_2": q_value_loss_2,
        "policy_loss": policy_loss,
        "alpha_loss": alpha_loss
    }
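
Both SAC variants call policy_net.rsample(states), which is not shown. A minimal sketch of a tanh-squashed Gaussian policy with that interface, assuming the usual reparameterized sample plus per-sample log-probability (the real network layout may differ):

import torch
import torch.nn as nn

class TanhGaussianPolicy(nn.Module):
    def __init__(self, state_dim, action_dim, hidden_dim=256):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(state_dim, hidden_dim), nn.ReLU())
        self.mean_head = nn.Linear(hidden_dim, action_dim)
        self.log_std_head = nn.Linear(hidden_dim, action_dim)

    def rsample(self, states):
        h = self.body(states)
        mean = self.mean_head(h)
        log_std = self.log_std_head(h).clamp(-20, 2)
        normal = torch.distributions.Normal(mean, log_std.exp())
        pre_tanh = normal.rsample()               # reparameterized sample
        actions = torch.tanh(pre_tanh)
        # tanh change-of-variables correction to the Gaussian log-prob
        log_probs = normal.log_prob(pre_tanh) - torch.log(1 - actions.pow(2) + 1e-6)
        return actions, log_probs.sum(dim=-1, keepdim=True)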
Example #9
def sac_step(policy_net,
             value_net,
             value_net_target,
             q_net_1,
             q_net_2,
             optimizer_policy,
             optimizer_value,
             optimizer_q_net_1,
             optimizer_q_net_2,
             states,
             actions,
             rewards,
             next_states,
             masks,
             gamma,
             polyak,
             update_target=False):
    rewards = rewards.unsqueeze(-1)
    masks = masks.unsqueeze(-1)
    """update qvalue net"""

    q_value_1 = q_net_1(states, actions)
    q_value_2 = q_net_2(states, actions)
    with torch.no_grad():
        target_next_value = rewards + gamma * \
            masks * value_net_target(next_states)

    q_value_loss_1 = nn.MSELoss()(q_value_1, target_next_value)
    optimizer_q_net_1.zero_grad()
    q_value_loss_1.backward()
    optimizer_q_net_1.step()

    q_value_loss_2 = nn.MSELoss()(q_value_2, target_next_value)
    optimizer_q_net_2.zero_grad()
    q_value_loss_2.backward()
    optimizer_q_net_2.step()
    """update policy net"""
    new_actions, log_probs = policy_net.rsample(states)
    min_q = torch.min(q_net_1(states, new_actions),
                      q_net_2(states, new_actions))
    policy_loss = (log_probs - min_q).mean()
    optimizer_policy.zero_grad()
    policy_loss.backward()
    optimizer_policy.step()
    """update value net"""
    target_value = (min_q - log_probs).detach()
    value_loss = nn.MSELoss()(value_net(states), target_value)
    optimizer_value.zero_grad()
    value_loss.backward()
    optimizer_value.step()

    if update_target:
        """ update target value net """
        value_net_target_flat_params = get_flat_params(value_net_target)
        value_net_flat_params = get_flat_params(value_net)

        set_flat_params(value_net_target,
                        (1 - polyak) * value_net_flat_params +
                        polyak * value_net_target_flat_params)

    return {
        "target_value_loss": value_loss,
        "q_value_loss_1": q_value_loss_1,
        "q_value_loss_2": q_value_loss_2,
        "policy_loss": policy_loss
    }
Example #10
def td3_step(policy_net,
             policy_net_target,
             value_net_1,
             value_net_target_1,
             value_net_2,
             value_net_target_2,
             optimizer_policy,
             optimizer_value_1,
             optimizer_value_2,
             states,
             actions,
             rewards,
             next_states,
             masks,
             gamma,
             polyak,
             target_action_noise_std,
             target_action_noise_clip,
             action_high,
             update_policy=False):
    rewards = rewards.unsqueeze(-1)
    masks = masks.unsqueeze(-1)
    """update critic"""
    with torch.no_grad():
        target_action = policy_net_target(next_states)
        target_action_noise = torch.clamp(
            torch.randn_like(target_action) * target_action_noise_std,
            -target_action_noise_clip, target_action_noise_clip)
        target_action = torch.clamp(target_action + target_action_noise,
                                    -action_high, action_high)
        target_values = rewards + gamma * masks * torch.min(
            value_net_target_1(next_states, target_action),
            value_net_target_2(next_states, target_action))
    """update value1 target"""
    values_1 = value_net_1(states, actions)
    value_loss_1 = nn.MSELoss()(target_values, values_1)

    optimizer_value_1.zero_grad()
    value_loss_1.backward()
    optimizer_value_1.step()
    """update value2 target"""
    values_2 = value_net_2(states, actions)
    value_loss_2 = nn.MSELoss()(target_values, values_2)

    optimizer_value_2.zero_grad()
    value_loss_2.backward()
    optimizer_value_2.step()

    policy_loss = None
    if update_policy:
        """update policy"""
        policy_loss = -value_net_1(states, policy_net(states)).mean()

        optimizer_policy.zero_grad()
        policy_loss.backward()
        optimizer_policy.step()
        """soft update target nets"""
        policy_net_flat_params = get_flat_params(policy_net)
        policy_net_target_flat_params = get_flat_params(policy_net_target)
        set_flat_params(
            policy_net_target, polyak * policy_net_target_flat_params +
            (1 - polyak) * policy_net_flat_params)

        value_net_1_flat_params = get_flat_params(value_net_1)
        value_net_1_target_flat_params = get_flat_params(value_net_target_1)
        set_flat_params(
            value_net_target_1, polyak * value_net_1_target_flat_params +
            (1 - polyak) * value_net_1_flat_params)

        value_net_2_flat_params = get_flat_params(value_net_2)
        value_net_2_target_flat_params = get_flat_params(value_net_target_2)
        set_flat_params(
            value_net_target_2, polyak * value_net_2_target_flat_params +
            (1 - polyak) * value_net_2_flat_params)

    return value_loss_1, value_loss_2, policy_loss