Example 1
import torch

from machina import loss_functional as lf  # assumed import path for machina's SAC loss helpers
from machina import logger  # assumed import path for machina's logger


def train(traj,
          pol, qfs, targ_qfs, log_alpha,
          optim_pol, optim_qfs, optim_alpha,
          epoch, batch_size,  # optimization hypers
          tau, gamma, sampling, reparam=True,
          log_enable=True,
          max_grad_norm=0.5,
          ):
    """
    Training function for Soft Actor-Critic (SAC).

    Parameters
    ----------
    traj : Traj
        Off-policy trajectory.
    pol : Pol
        Policy.
    qfs : list of SAVfunction
        Q functions.
    targ_qfs : list of SAVfunction
        Target Q functions.
    log_alpha : torch.Tensor
        Temperature parameter of entropy.
    optim_pol : torch.optim.Optimizer
        Optimizer for Policy.
    optim_qfs : list of torch.optim.Optimizer
        Optimizers for the Q functions.
    optim_alpha : torch.optim.Optimizer
        Optimizer for alpha.
    epoch : int
        Number of optimization iterations (minibatches drawn from traj).
    batch_size : int
        Size of each sampled minibatch.
    tau : float
        Target updating rate.
    gamma : float
        Discounting rate.
    sampling : int
        Number of samples used when computing the expectation.
    reparam : bool
        If True, use the reparameterization trick when estimating the policy loss.
    log_enable : bool
        If True, enable logging.
    max_grad_norm : float
        Maximum gradient norm.

    Returns
    -------
    result_dict : dict
        Dictionary containing the per-batch losses.
    """

    pol_losses = []
    _qf_losses = []
    alpha_losses = []
    if log_enable:
        logger.log("Optimizing...")
    for batch in traj.random_batch(batch_size, epoch):
        pol_loss, qf_losses, alpha_loss = lf.sac(
            pol, qfs, targ_qfs, log_alpha, batch, gamma, sampling, reparam)

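        # Actor (policy) update with gradient clipping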
        optim_pol.zero_grad()
        pol_loss.backward()
        torch.nn.utils.clip_grad_norm_(pol.parameters(), max_grad_norm)
        optim_pol.step()

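        # Critic updates: one optimizer per Q function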
        for qf, optim_qf, qf_loss in zip(qfs, optim_qfs, qf_losses):
            optim_qf.zero_grad()
            qf_loss.backward()
            torch.nn.utils.clip_grad_norm_(qf.parameters(), max_grad_norm)
            optim_qf.step()

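        # Entropy temperature (alpha) update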
        optim_alpha.zero_grad()
        alpha_loss.backward()
        optim_alpha.step()

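        # Soft (Polyak) update of the target Q networks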
        for qf, targ_qf in zip(qfs, targ_qfs):
            for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
                targ_q.detach().copy_((1 - tau) * targ_q.detach() + tau * q.detach())

        pol_losses.append(pol_loss.detach().cpu().numpy())
        _qf_losses.append(
            (sum(qf_losses) / len(qf_losses)).detach().cpu().numpy())
        alpha_losses.append(alpha_loss.detach().cpu().numpy())

    if log_enable:
        logger.log("Optimization finished!")

    return dict(
        PolLoss=pol_losses,
        QfLoss=_qf_losses,
        AlphaLoss=alpha_losses
    )
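
Both examples delegate the actual loss computation to lf.sac and only step the optimizers. For orientation, the temperature term that optim_alpha is stepping usually follows the standard SAC objective, in which alpha is tuned so that the policy entropy stays near a target value. The sketch below illustrates that generic objective only; it is not machina's lf.sac implementation, and the names alpha_loss_sketch, log_pis, and target_entropy are invented for this illustration.

import torch

def alpha_loss_sketch(log_alpha, log_pis, target_entropy):
    # Standard SAC temperature loss: alpha grows when the policy entropy
    # (-log_pis) drops below target_entropy, and shrinks otherwise.
    return -(torch.exp(log_alpha) * (log_pis + target_entropy).detach()).mean()

# Illustrative usage with dummy values, mirroring the log_alpha / optim_alpha
# pattern used in train() above.
log_alpha = torch.zeros(1, requires_grad=True)
optim_alpha = torch.optim.Adam([log_alpha], lr=3e-4)
log_pis = torch.randn(256)                # stand-in for log pi(a|s) over a batch
loss = alpha_loss_sketch(log_alpha, log_pis, target_entropy=-6.0)
optim_alpha.zero_grad()
loss.backward()
optim_alpha.step()
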
Example 2
import torch

from machina import loss_functional as lf  # assumed import path for machina's SAC loss helpers
from machina import logger  # assumed import path for machina's logger

# calc_rewards (used below) is assumed to be defined elsewhere in this example;
# it maps observations to discriminator-based intrinsic rewards.


def train(
        traj,
        pol,
        qfs,
        targ_qfs,
        log_alpha,
        optim_pol,
        optim_qfs,
        optim_alpha,
        epoch,
        batch_size,  # optimization hypers
        tau,
        gamma,
        sampling,
        discrim,
        num_skill,
        reparam=True):
    """
    Training function for Soft Actor-Critic (SAC) with discriminator-based skill rewards.

    Parameters
    ----------
    traj : Traj
        Off-policy trajectory.
    pol : Pol
        Policy.
    qfs : list of SAVfunction
        Q functions.
    targ_qfs : list of SAVfunction
        Target Q functions.
    log_alpha : torch.Tensor
        Temperature parameter of entropy.
    optim_pol : torch.optim.Optimizer
        Optimizer for Policy.
    optim_qfs : list of torch.optim.Optimizer
        Optimizers for the Q functions.
    optim_alpha : torch.optim.Optimizer
        Optimizer for alpha.
    epoch : int
        Number of optimization iterations (minibatches drawn from traj).
    batch_size : int
        Size of each sampled minibatch.
    tau : float
        Target updating rate.
    gamma : float
        Discounting rate.
    sampling : int
        Number of samples used when computing the expectation.
    reparam : bool
        If True, use the reparameterization trick when estimating the policy loss.
    discrim : SVfunction
        Discriminator used to compute intrinsic (skill) rewards.
    num_skill : int
        Number of skills.

    Returns
    -------
    result_dict : dict
        Dictionary containing the per-batch losses.
    """

    pol_losses = []
    _qf_losses = []
    alpha_losses = []
    logger.log("Optimizing...")
    for batch in traj.random_batch(batch_size, epoch):
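        # Replace the stored rewards with discriminator-based intrinsic (skill) rewards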
        with torch.no_grad():
            rews, info = calc_rewards(batch['obs'], num_skill, discrim)
            batch['rews'] = rews

        pol_loss, qf_losses, alpha_loss = lf.sac(pol, qfs, targ_qfs, log_alpha,
                                                 batch, gamma, sampling,
                                                 reparam)

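        # Same SAC updates as in Example 1, but without gradient clipping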
        optim_pol.zero_grad()
        pol_loss.backward()
        optim_pol.step()

        for optim_qf, qf_loss in zip(optim_qfs, qf_losses):
            optim_qf.zero_grad()
            qf_loss.backward()
            optim_qf.step()

        optim_alpha.zero_grad()
        alpha_loss.backward()
        optim_alpha.step()

        for qf, targ_qf in zip(qfs, targ_qfs):
            for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
                targ_q.detach().copy_((1 - tau) * targ_q.detach() +
                                      tau * q.detach())

        pol_losses.append(pol_loss.detach().cpu().numpy())
        _qf_losses.append(
            (sum(qf_losses) / len(qf_losses)).detach().cpu().numpy())
        alpha_losses.append(alpha_loss.detach().cpu().numpy())

    logger.log("Optimization finished!")

    return dict(PolLoss=pol_losses, QfLoss=_qf_losses, AlphaLoss=alpha_losses)
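
Example 2 relies on a calc_rewards helper that is not part of the snippet: before each SAC update it overwrites batch['rews'] with discriminator-based intrinsic rewards. A minimal sketch of such a helper is given below, assuming a DIAYN-style setup in which the last num_skill entries of each observation are a one-hot skill vector and discrim maps the remaining state features to unnormalized logits over skills. The observation layout and the discrim call signature are assumptions made for this illustration, not something the snippet specifies.

import torch
import torch.nn.functional as F

def calc_rewards(obs, num_skill, discrim):
    # Hypothetical DIAYN-style intrinsic reward: log q(z | s) - log p(z),
    # where z is the skill and p(z) is a uniform prior over num_skill skills.
    state = obs[:, :-num_skill]
    skill_onehot = obs[:, -num_skill:]
    logits = discrim(state)                             # assumed to return (batch, num_skill) logits
    log_q_z_given_s = (F.log_softmax(logits, dim=-1) * skill_onehot).sum(dim=-1)
    log_p_z = torch.log(torch.tensor(1.0 / num_skill))  # uniform skill prior
    rews = log_q_z_given_s - log_p_z
    info = dict(mean_rew=float(rews.mean()))
    return rews, info

The (rews, info) return pair matches how the snippet consumes the result: rews replaces batch['rews'] inside the torch.no_grad() block, and info is not used further.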