Example #1
def train(
    traj,
    pol,
    targ_pol,
    qf,
    targ_qf,
    optim_pol,
    optim_qf,
    epoch,
    batch_size,  # optimization hypers
    tau,
    gamma,
    log_enable=True,
):
    """
    Train function for deep deterministic policy gradient with
    prioritized experience replay.
    """

    pol_losses = []
    qf_losses = []
    if log_enable:
        logger.log("Optimizing...")
    for batch, indices in traj.prioritized_random_batch(batch_size,
                                                        epoch,
                                                        return_indices=True):
        qf_bellman_loss = lf.bellman(qf,
                                     targ_qf,
                                     targ_pol,
                                     batch,
                                     gamma,
                                     reduction='none')
        td_loss = torch.sqrt(qf_bellman_loss * 2)
        qf_bellman_loss = torch.mean(qf_bellman_loss)
        optim_qf.zero_grad()
        qf_bellman_loss.backward()
        optim_qf.step()

        pol_loss = lf.ag(pol, qf, batch)
        optim_pol.zero_grad()
        pol_loss.backward()
        optim_pol.step()

        for p, targ_p in zip(pol.parameters(), targ_pol.parameters()):
            targ_p.detach().copy_((1 - tau) * targ_p.detach() +
                                  tau * p.detach())
        for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
            targ_q.detach().copy_((1 - tau) * targ_q.detach() +
                                  tau * q.detach())

        qf_losses.append(qf_bellman_loss.detach().cpu().numpy())
        pol_losses.append(pol_loss.detach().cpu().numpy())

        traj = tf.update_pris(traj, td_loss, indices)
    if log_enable:
        logger.log("Optimization finished!")

    return {'PolLoss': pol_losses, 'QfLoss': qf_losses}
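A note on the TD-error recovery above: Example #1 asks lf.bellman for per-sample losses (reduction='none') and converts them back to absolute TD errors with torch.sqrt(qf_bellman_loss * 2) before updating the priorities. That identity only holds if the per-sample loss is the half squared TD error, 0.5 * td**2. The following self-contained sketch (plain PyTorch; the names are illustrative, not machina's API) demonstrates the relationship.

import torch

def td_error_from_half_squared_loss(per_sample_loss):
    # Recover |td| from a per-sample loss of the form 0.5 * td ** 2.
    return torch.sqrt(2 * per_sample_loss)

td = torch.tensor([0.5, -1.0, 2.0])      # hypothetical TD errors
per_sample_loss = 0.5 * td ** 2          # what a half-squared Bellman loss would return
print(td_error_from_half_squared_loss(per_sample_loss))  # tensor([0.5000, 1.0000, 2.0000]) == |td|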
Example #2
def train(
        traj,
        pol,
        targ_pol,
        qf,
        targ_qf,
        optim_pol,
        optim_qf,
        epoch,
        batch_size,  # optimization hypers
        tau,
        gamma  # advantage estimation
):
    """
    Train function for deep deterministic policy gradient

    Parameters
    ----------
    traj : Traj
        Off-policy trajectory.
    pol : Pol
        Policy.
    targ_pol : Pol
        Target Policy.
    qf : SAVfunction
        Q function.
    targ_qf : SAVfunction
        Target Q function.
    optim_pol : torch.optim.Optimizer
        Optimizer for Policy.
    optim_qf : torch.optim.Optimizer
        Optimizer for Q function.
    epoch : int
        Number of iterations.
    batch_size : int
        Batch size.
    tau : float
        Target updating rate.
    gamma : float
        Discounting rate.

    Returns
    -------
    result_dict : dict
        Dictionary containing loss information.
    """

    pol_losses = []
    qf_losses = []
    logger.log("Optimizing...")
    for batch in traj.random_batch(batch_size, epoch):
        qf_bellman_loss = lf.bellman(qf, targ_qf, targ_pol, batch, gamma)
        optim_qf.zero_grad()
        qf_bellman_loss.backward()
        optim_qf.step()

        pol_loss = lf.ag(pol, qf, batch, no_noise=True)
        optim_pol.zero_grad()
        pol_loss.backward()
        optim_pol.step()

        for p, targ_p in zip(pol.parameters(), targ_pol.parameters()):
            targ_p.detach().copy_((1 - tau) * targ_p.detach() +
                                  tau * p.detach())
        for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
            targ_q.detach().copy_((1 - tau) * targ_q.detach() +
                                  tau * q.detach())

        qf_losses.append(qf_bellman_loss.detach().cpu().numpy())
        pol_losses.append(pol_loss.detach().cpu().numpy())
    logger.log("Optimization finished!")

    return {'PolLoss': pol_losses, 'QfLoss': qf_losses}
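Every variant in this section refreshes its target networks with the same element-wise soft (Polyak) update, targ_p <- (1 - tau) * targ_p + tau * p, written with detach().copy_(). Below is a minimal standalone sketch of the same update using torch.no_grad(), assuming both modules expose their parameters in the same order; the helper name soft_update is hypothetical, not part of machina.

import torch
import torch.nn as nn

def soft_update(net, targ_net, tau):
    # In-place Polyak averaging: targ_p <- (1 - tau) * targ_p + tau * p.
    with torch.no_grad():
        for p, targ_p in zip(net.parameters(), targ_net.parameters()):
            targ_p.mul_(1.0 - tau).add_(tau * p)

net = nn.Linear(4, 2)
targ_net = nn.Linear(4, 2)
targ_net.load_state_dict(net.state_dict())  # target networks usually start as copies
soft_update(net, targ_net, tau=0.005)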
Example #3
def train(
    traj,
    pol,
    targ_pol,
    qf,
    targ_qf,
    optim_pol,
    optim_qf,
    epoch,
    batch_size,  # optimization hypers
    tau,
    gamma,  # advantage estimation
    sampling,
    log_enable=True,
):
    """
    Train function for deep deterministic policy gradient

    Parameters
    ----------
    traj : Traj
        Off-policy trajectory.
    pol : Pol
        Policy.
    targ_pol : Pol
        Target Policy.
    qf : SAVfunction
        Q function.
    targ_qf : SAVfunction
        Target Q function.
    optim_pol : torch.optim.Optimizer
        Optimizer for Policy.
    optim_qf : torch.optim.Optimizer
        Optimizer for Q function.
    epoch : int
        Number of iterations.
    batch_size : int
        Batch size.
    tau : float
        Target updating rate.
    gamma : float
        Discounting rate.
    sampling : int
        Number of samples used to approximate the expectation.
    log_enable : bool
        If True, enable logging.

    Returns
    -------
    result_dict : dict
        Dictionary containing loss information.
    """

    pol_losses = []
    qf_losses = []
    if log_enable:
        logger.log("Optimizing...")
    for batch in traj.iterate(batch_size, epoch):
        qf_bellman_loss = lf.bellman(qf,
                                     targ_qf,
                                     targ_pol,
                                     batch,
                                     gamma,
                                     sampling=sampling)
        optim_qf.zero_grad()
        qf_bellman_loss.backward()
        optim_qf.step()

        pol_loss = lf.ag(pol, qf, batch, sampling)
        optim_pol.zero_grad()
        pol_loss.backward()
        optim_pol.step()

        for q, targ_q, p, targ_p in zip(qf.parameters(), targ_qf.parameters(),
                                        pol.parameters(),
                                        targ_pol.parameters()):
            targ_p.detach().copy_((1 - tau) * targ_p.detach() +
                                  tau * p.detach())
            targ_q.detach().copy_((1 - tau) * targ_q.detach() +
                                  tau * q.detach())
        qf_losses.append(qf_bellman_loss.detach().cpu().numpy())
        pol_losses.append(pol_loss.detach().cpu().numpy())

    if log_enable:
        logger.log("Optimization finished!")

    return dict(
        PolLoss=pol_losses,
        QfLoss=qf_losses,
    )
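Example #3 differs from the deterministic variants in that it passes a sampling count, so the Q values inside the Bellman target and the policy objective are presumably approximated as Monte-Carlo expectations over actions drawn from a stochastic policy. A self-contained sketch of that estimator follows, with toy stand-ins for the action sampler and Q function (all names here are illustrative, not machina's API).

import torch

def expected_q(qf, sample_action, obs, sampling):
    # Monte-Carlo estimate of E_{a ~ pi(.|obs)}[Q(obs, a)] over `sampling` draws.
    q_values = [qf(obs, sample_action(obs)) for _ in range(sampling)]
    return torch.stack(q_values).mean(dim=0)

obs = torch.randn(8, 3)                                  # batch of 8 toy observations
weight = torch.randn(3, 2)
sample_action = lambda o: torch.tanh(o @ weight + 0.1 * torch.randn(o.shape[0], 2))
qf = lambda o, a: o.sum(dim=1) + a.sum(dim=1)            # toy Q function
print(expected_q(qf, sample_action, obs, sampling=10).shape)  # torch.Size([8])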
Example #4
def train(
    traj,
    pol,
    targ_pol,
    qfs,
    targ_qfs,
    optim_pol,
    optim_qfs,
    epoch,
    batch_size,  # optimization hypers
    tau,
    gamma,  # advantage estimation
    pol_update=True,
    log_enable=True,
    max_grad_norm=0.5,
    target_policy_smoothing_func=None,
):
    """
    Train function for twin delayed deep deterministic policy gradient (TD3).
    """

    pol_losses = []
    _qf_losses = []
    if log_enable:
        logger.log("Optimizing...")

    for batch in traj.random_batch(batch_size, epoch):

        if target_policy_smoothing_func is not None:
            qf_losses = lf.td3(
                qfs,
                targ_qfs,
                targ_pol,
                batch,
                gamma,
                continuous=True,
                deterministic=True,
                sampling=1,
                target_policy_smoothing_func=target_policy_smoothing_func)
        else:
            qf_losses = lf.td3(qfs,
                               targ_qfs,
                               targ_pol,
                               batch,
                               gamma,
                               continuous=True,
                               deterministic=True,
                               sampling=1)

        for qf, optim_qf, qf_loss in zip(qfs, optim_qfs, qf_losses):
            optim_qf.zero_grad()
            qf_loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(qf.parameters(), max_grad_norm)
            optim_qf.step()

        _qf_losses.append(
            (sum(qf_losses) / len(qf_losses)).detach().cpu().numpy())

        if pol_update:
            pol_loss = lf.ag(pol, qfs[0], batch, no_noise=True)
            optim_pol.zero_grad()
            pol_loss.backward()
            if max_grad_norm is not None:
                torch.nn.utils.clip_grad_norm_(pol.parameters(), max_grad_norm)
            optim_pol.step()

            for p, targ_p in zip(pol.parameters(), targ_pol.parameters()):
                targ_p.detach().copy_((1 - tau) * targ_p.detach() +
                                      tau * p.detach())

            for qf, targ_qf in zip(qfs, targ_qfs):
                for q, targ_q in zip(qf.parameters(), targ_qf.parameters()):
                    targ_q.detach().copy_((1 - tau) * targ_q.detach() +
                                          tau * q.detach())

            pol_losses.append(pol_loss.detach().cpu().numpy())

    if log_enable:
        logger.log("Optimization finished!")
    if pol_update:
        return dict(
            PolLoss=pol_losses,
            QfLoss=_qf_losses,
        )
    else:
        return dict(QfLoss=_qf_losses)
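Example #4 is the twin-delayed (TD3) variant: two critics trained through lf.td3 with gradient norm clipping, and an actor plus target-network update that runs only when pol_update is True, which a caller would typically enable on every second invocation to get TD3's delayed policy updates. It also accepts an optional target_policy_smoothing_func. In the TD3 paper, target policy smoothing perturbs the target action with clipped Gaussian noise before the target critics evaluate it; a minimal sketch of such a function is below (the exact callable signature machina expects may differ, and the noise scale and action bounds are illustrative).

import torch

def smooth_target_action(act, noise_scale=0.2, noise_clip=0.5, act_limit=1.0):
    # TD3-style target policy smoothing: clipped Gaussian noise, then clip to action bounds.
    noise = torch.clamp(noise_scale * torch.randn_like(act), -noise_clip, noise_clip)
    return torch.clamp(act + noise, -act_limit, act_limit)

act = torch.tensor([[0.9, -0.3]])
print(smooth_target_action(act))  # perturbed action, still within [-1.0, 1.0]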