def forward_g(scene,
              his_traj,
              targets,
              model_g,
              model_d,
              optimizer,
              scheduler,
              omega=1.0,
              epsilon=1.0):
    """Run one training step of the CVAE generator.

    The generator's multi-modal prediction is reduced to a single trajectory
    with ``utils.multi2single(..., mode='best')``, scored by the
    discriminator, and the resulting adversarial loss is combined with the
    CVAE loss from ``cvae.loss_cvae``.

    Args:
        scene: scene input, passed unchanged to ``model_g``.
        his_traj: history trajectory, passed unchanged to ``model_g``.
        targets: ground-truth future trajectory.
        model_g: generator; called as ``model_g(scene, his_traj)`` and
            unpacked as ``(preds, conf, context, z_mean, z_var)``.
        model_d: discriminator; called as ``model_d(traj, context)``.
        optimizer: optimizer that is stepped here (presumably over the
            generator's parameters -- confirm with the caller).
        scheduler: LR scheduler for ``optimizer``; stepped once per call.
        omega: weight of the NLL term. Currently unused because the NLL term
            is disabled; kept for interface compatibility.
        epsilon: weight applied to the CVAE loss.

    Returns:
        Tuple ``(loss, vae_loss, ade_loss, preds, conf)``.
    """
    model_g.train()
    preds, conf, context, z_mean, z_var = model_g(scene, his_traj)
    traj_fake = utils.multi2single(preds, targets, conf, mode='best')
    # Swap the first two axes before scoring; presumably model_d expects a
    # time-major layout -- TODO confirm against the discriminator.
    score_fake = model_d(traj_fake.permute(1, 0, 2), context)

    # Total loss = adversarial (discriminator) loss + CVAE loss * epsilon.
    # The NLL term (utils.pytorch_neg_multi_log_likelihood_batch weighted by
    # omega) is intentionally disabled.
    g_loss = utils.g_loss(score_fake)
    vae_loss, ade_loss = cvae.loss_cvae(targets, preds, conf, z_mean, z_var)
    loss = g_loss + vae_loss * epsilon

    optimizer.zero_grad()
    # NOTE: backward also fills .grad on model_d's parameters (its output is
    # part of the loss); only `optimizer` is stepped here.
    loss.backward()
    optimizer.step()
    # Since PyTorch 1.1 the scheduler must be stepped *after* the optimizer;
    # stepping it first skips the first value of the LR schedule.
    scheduler.step()
    return loss, vae_loss, ade_loss, preds, conf
Ejemplo n.º 2
0
def forward_g(scene,
              his_traj,
              targets,
              model_g,
              model_d,
              optimizer,
              scheduler,
              omega=1.0,
              epsilon=1.0):
    """Run one training step of the generator with a min-ADE objective.

    The generator's multi-modal prediction is reduced to a single trajectory
    with ``utils.multi2single(..., mode='best')``, scored by the
    discriminator, and the resulting adversarial loss is combined with the
    displacement error returned by
    ``utils._average_displacement_error(..., mode='best')``.

    Args:
        scene: scene input, passed unchanged to ``model_g``.
        his_traj: history trajectory, passed unchanged to ``model_g``.
        targets: ground-truth future trajectory.
        model_g: generator; called as ``model_g(scene, his_traj)`` and
            unpacked as ``(preds, conf, context)``.
        model_d: discriminator; called as ``model_d(traj, context)``.
        optimizer: optimizer that is stepped here (presumably over the
            generator's parameters -- confirm with the caller).
        scheduler: LR scheduler for ``optimizer``; stepped once per call.
        omega: weight of the NLL term. The NLL term is currently disabled
            (fixed at 0.0), so this has no effect for finite values.
        epsilon: weight applied to the displacement-error loss.

    Returns:
        Tuple ``(loss, nll_loss, min_l2_loss, preds, conf)``; ``nll_loss`` is
        always the float ``0.0`` while the NLL term is disabled.
    """
    model_g.train()
    preds, conf, context = model_g(scene, his_traj)
    traj_fake = utils.multi2single(preds, targets, conf, mode='best')
    # Swap the first two axes before scoring; presumably model_d expects a
    # time-major layout -- TODO confirm against the discriminator.
    score_fake = model_d(traj_fake.permute(1, 0, 2), context)

    # Total loss = adversarial (discriminator) loss + NLL * omega
    #              + displacement error * epsilon.
    g_loss = utils.g_loss(score_fake)
    min_l2_loss = utils._average_displacement_error(targets,
                                                    preds,
                                                    conf,
                                                    mode='best')
    # Placeholder: the NLL term
    # (utils.pytorch_neg_multi_log_likelihood_batch) is disabled, but the
    # value is still returned so callers keep the same tuple layout.
    nll_loss = 0.0
    loss = g_loss + nll_loss * omega + min_l2_loss * epsilon

    optimizer.zero_grad()
    # NOTE: backward also fills .grad on model_d's parameters (its output is
    # part of the loss); only `optimizer` is stepped here.
    loss.backward()
    optimizer.step()
    # Since PyTorch 1.1 the scheduler must be stepped *after* the optimizer;
    # stepping it first skips the first value of the LR schedule.
    scheduler.step()
    return loss, nll_loss, min_l2_loss, preds, conf
def forward_d(scene, his_traj, targets, model_g, model_d, optimizer,
              scheduler):
    """Run one training step of the discriminator.

    The generator's prediction is reduced to a single "fake" trajectory with
    ``utils.multi2single(..., mode='best')``. The discriminator scores both
    the fake and the ground-truth trajectory, and ``utils.d_loss`` combines
    the two scores.

    Args:
        scene: scene input, passed unchanged to ``model_g``.
        his_traj: history trajectory, passed unchanged to ``model_g``.
        targets: ground-truth future trajectory (the "real" sample).
        model_g: generator; called as ``model_g(scene, his_traj)`` and
            unpacked as ``(preds, confidences, context, _, _)``.
        model_d: discriminator; called as ``model_d(traj, context)``.
        optimizer: optimizer that is stepped here (presumably over the
            discriminator's parameters -- confirm with the caller).
        scheduler: LR scheduler for ``optimizer``; stepped once per call.

    Returns:
        The discriminator loss tensor.
    """
    model_d.train()
    # The generator is not updated in this step, so run it without autograd.
    # Previously its graph was built and back-propagated through on every
    # call, wasting memory/compute and accumulating unused grads on model_g.
    with torch.no_grad():
        preds, confidences, context, _, _ = model_g(scene, his_traj)
        traj_fake = utils.multi2single(preds, targets, confidences,
                                       mode='best')
    # Swap the first two axes before scoring; presumably model_d expects a
    # time-major layout -- TODO confirm against the discriminator.
    score_fake = model_d(traj_fake.permute(1, 0, 2), context)
    score_real = model_d(targets.permute(1, 0, 2), context)
    loss = utils.d_loss(score_real, score_fake)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Since PyTorch 1.1 the scheduler must be stepped *after* the optimizer;
    # stepping it first skips the first value of the LR schedule.
    scheduler.step()
    return loss
                           mininterval=5.)
 for j in valid_progress_bar:
     try:
         data_valid = next(tr_it_valid)
     except StopIteration:
         tr_it_valid = iter(valid_dataloader)
         data_valid = next(tr_it_valid)
     scene_valid = data_valid[0].to(device)
     scene_valid = scene_valid.permute(0, 3, 2, 1)
     his_traj_valid = data_valid[3].to(device)
     his_traj_valid = his_traj_valid.permute(1, 0, 2)
     targets_valid = data_valid[4].to(device)
     pred_pixel, conf, context, z_mean, z_var = generator(
         scene_valid.float(), his_traj_valid.float())
     traj_fake_valid = utils.multi2single(pred_pixel,
                                          targets_valid.float(),
                                          conf,
                                          mode='best')
     score_fake = discriminator(
         traj_fake_valid.permute(1, 0, 2), context)
     g_loss_valid = utils.g_loss(score_fake)
     # nll_loss_valid = utils.pytorch_neg_multi_log_likelihood_batch(targets_valid, pred_pixel, conf)
     vae_loss_valid, min_l2_loss_valid = cvae.loss_cvae(
         targets_valid, pred_pixel, conf, z_mean, z_var)
     #     loss = g_loss + nll_loss * omega + l2_loss * epsilon
     # valid_loss = g_loss_valid + nll_loss_valid * omega + vae_loss_valid * epsilon
     valid_loss = g_loss_valid + vae_loss_valid * epsilon
     # camera frame to world frame(meter)
     pred = torch.zeros_like(pred_pixel)
     for batch_index in range(pred_pixel.shape[0]):
         for modality in range(pred_pixel.shape[1]):
             for pos_index in range(pred_pixel.shape[2]):