Esempio n. 1
0
def evaluate(loader, generator):
    """Return (ADE, FDE) of `generator` over every batch in `loader`.

    Each batch is moved to the GPU and the generator is sampled NUM_SAMPLES
    times without recording gradients. The raw per-sample displacement errors
    of a batch are reduced by `evaluate_helper`. The reduced values are summed
    over batches and divided by the trajectory count (times PRED_LEN for ADE).
    """
    ade_per_batch = []
    fde_per_batch = []
    traj_count = 0
    with torch.no_grad():
        for raw_batch in loader:
            (obs_traj, pred_traj_gt, obs_traj_rel,
             _pred_traj_gt_rel, vgg_list) = [t.cuda() for t in raw_batch]
            # size(1) of the ground truth is what gets counted as trajectories.
            traj_count += pred_traj_gt.size(1)
            # Anchor for relative_to_abs (presumably the last observed position).
            start_pos = obs_traj[-1, :, 0, :]

            sample_ade, sample_fde = [], []
            for _ in range(NUM_SAMPLES):
                pred_traj_fake = relative_to_abs(
                    generator(obs_traj, obs_traj_rel, vgg_list), start_pos)
                sample_ade.append(
                    displacement_error(pred_traj_fake, pred_traj_gt, mode='raw'))
                sample_fde.append(
                    final_displacement_error(pred_traj_fake[-1], pred_traj_gt[-1],
                                             mode='raw'))
            ade_per_batch.append(evaluate_helper(sample_ade))
            fde_per_batch.append(evaluate_helper(sample_fde))

        mean_ade = sum(ade_per_batch) / (traj_count * PRED_LEN)
        mean_fde = sum(fde_per_batch) / traj_count
    return mean_ade, mean_fde
Esempio n. 2
0
def evaluate(loader, generator, num_samples):
    """Sample the generator `num_samples` times per batch and return (ADE, FDE).

    Per-sample raw errors are reduced per batch by `evaluate_helper`, summed
    over batches and divided by the total number of trajectories (times
    PRED_LEN for ADE).

    Side effects, depending on module flags:
      * TEST_METRIC and VERIFY_OUTPUT_SPEED: the sampled trajectories and the
        cumulative per-scene pedestrian counts are passed to `verify_speed`.
      * ANIMATED_VISUALIZATION_CHECK: both are pickled to 'SimulatedTraj.pkl'
        and 'Sequences.pkl'.

    Args:
        loader: iterable of 8-tensor batches (see the unpacking below).
        generator: conditional trajectory generator.
        num_samples: number of generator samples drawn per batch.

    Returns:
        Tuple (ade, fde).
    """
    ade_outer, fde_outer, simulated_output, total_traj, sequences = [], [], [], [], []
    with torch.no_grad():
        for batch in loader:
            if USE_GPU:
                batch = [tensor.cuda() for tensor in batch]
            else:
                batch = list(batch)
            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, loss_mask,
             seq_start_end, obs_ped_speed, pred_ped_speed) = batch

            ade, fde, sim_op = [], [], []
            total_traj.append(pred_traj_gt.size(1))

            # Record the pedestrian count of every scene ONCE per batch. The
            # samples below are concatenated along dim 0, so the pedestrian
            # axis (dim 1) of the saved trajectories is not repeated per
            # sample. Appending these counts inside the sample loop made the
            # cumulative `sequences` num_samples times too long.
            for start, end in seq_start_end:
                sequences.append(end - start)

            for _ in range(num_samples):
                # TEST_METRIC is forwarded to the generator as-is.
                pred_traj_fake_rel = generator(obs_traj, obs_traj_rel, seq_start_end,
                                               obs_ped_speed, pred_ped_speed,
                                               pred_traj_gt, TEST_METRIC, SPEED_TO_ADD)
                pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
                ade.append(displacement_error(pred_traj_fake, pred_traj_gt, mode='raw'))
                fde.append(final_displacement_error(pred_traj_fake[-1], pred_traj_gt[-1],
                                                    mode='raw'))
                sim_op.append(pred_traj_fake)

            ade_outer.append(evaluate_helper(torch.stack(ade, dim=1), seq_start_end))
            fde_outer.append(evaluate_helper(torch.stack(fde, dim=1), seq_start_end))
            # Samples are stacked along dim 0 (one block per sample).
            simulated_output.append(torch.cat(sim_op, dim=0))

        ade = sum(ade_outer) / (sum(total_traj) * PRED_LEN)
        fde = sum(fde_outer) / sum(total_traj)
        # Batches are joined along dim 1, the axis counted by total_traj.
        simulated_traj_for_visualization = torch.cat(simulated_output, dim=1)
        sequences = torch.cumsum(torch.stack(sequences, dim=0), dim=0)

        if TEST_METRIC and VERIFY_OUTPUT_SPEED:
            # The speed can be verified for different sequences and this method runs for n number of batches.
            verify_speed(simulated_traj_for_visualization, sequences)

        if ANIMATED_VISUALIZATION_CHECK:
            # Trajectories at user-defined speed, for visualization.
            with open('SimulatedTraj.pkl', 'wb') as f:
                pickle.dump(simulated_traj_for_visualization, f, pickle.HIGHEST_PROTOCOL)
            # Sequence list file used for visualization.
            with open('Sequences.pkl', 'wb') as f:
                pickle.dump(sequences, f, pickle.HIGHEST_PROTOCOL)
        return ade, fde
def check_accuracy(loader, generator, discriminator, d_loss_fn):
    """Evaluate generator and discriminator on up to NUM_SAMPLE_CHECK trajectories.

    The generator is put in eval mode for the pass. It is always returned to
    train mode afterwards, even if evaluation raises.

    Args:
        loader: iterable of 8-tensor batches (see the unpacking below).
        generator: speed-conditioned trajectory generator.
        discriminator: scores concatenated (observed + predicted) trajectories.
        d_loss_fn: callable (scores_real, scores_fake) -> discriminator loss tensor.

    Returns:
        dict with keys 'd_loss', 'g_l2_loss_abs', 'g_l2_loss_rel', 'ade',
        'fde', 'msae', 'fse'.
    """
    d_losses = []
    metrics = {}
    # Two distinct lists. `([],) * 2` repeats ONE list object, so both names
    # aliased it and abs/rel losses were summed together into both metrics.
    g_l2_losses_abs, g_l2_losses_rel = [], []
    disp_error, f_disp_error, mean_speed_disp_error, final_speed_disp_error = [], [], [], []
    total_traj = 0
    loss_mask_sum = 0
    generator.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                if USE_GPU:
                    batch = [tensor.cuda() for tensor in batch]
                else:
                    batch = list(batch)
                (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, loss_mask,
                 seq_start_end, obs_ped_speed, pred_ped_speed) = batch

                pred_traj_fake_rel = generator(obs_traj, obs_traj_rel,
                                               seq_start_end, obs_ped_speed,
                                               pred_ped_speed, pred_traj_gt,
                                               TRAIN_METRIC, SPEED_TO_ADD)
                pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
                # Drop the first OBS_LEN columns of the mask (presumably the
                # observed steps) so it lines up with the predictions.
                loss_mask = loss_mask[:, OBS_LEN:]

                g_l2_loss_abs, g_l2_loss_rel = cal_l2_losses(
                    pred_traj_gt, pred_traj_gt_rel, pred_traj_fake,
                    pred_traj_fake_rel, loss_mask)
                ade = displacement_error(pred_traj_gt, pred_traj_fake)
                # NOTE(review): evaluate() passes only the last step ([-1]) to
                # final_displacement_error; here full trajectories are passed.
                # Confirm which form the helper expects.
                fde = final_displacement_error(pred_traj_gt, pred_traj_fake)

                # Prepend the last observed position so speeds can be derived
                # for every predicted step.
                last_pos = obs_traj[-1]
                traj_for_speed_cal = torch.cat(
                    [last_pos.unsqueeze(dim=0), pred_traj_fake], dim=0)
                msae = cal_msae(pred_ped_speed, traj_for_speed_cal)
                fse = cal_fse(pred_ped_speed[-1], pred_traj_fake)

                traj_real = torch.cat([obs_traj, pred_traj_gt], dim=0)
                traj_real_rel = torch.cat([obs_traj_rel, pred_traj_gt_rel], dim=0)
                traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
                traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)
                ped_speed = torch.cat([obs_ped_speed, pred_ped_speed], dim=0)

                scores_fake = discriminator(traj_fake, traj_fake_rel, ped_speed,
                                            seq_start_end)
                scores_real = discriminator(traj_real, traj_real_rel, ped_speed,
                                            seq_start_end)

                d_loss = d_loss_fn(scores_real, scores_fake)
                d_losses.append(d_loss.item())

                g_l2_losses_abs.append(g_l2_loss_abs.item())
                g_l2_losses_rel.append(g_l2_loss_rel.item())
                disp_error.append(ade.item())
                f_disp_error.append(fde.item())
                mean_speed_disp_error.append(msae.item())
                final_speed_disp_error.append(fse.item())

                loss_mask_sum += torch.numel(loss_mask.data)
                total_traj += pred_traj_gt.size(1)
                if total_traj >= NUM_SAMPLE_CHECK:
                    break

        metrics['d_loss'] = sum(d_losses) / len(d_losses)
        metrics['g_l2_loss_abs'] = sum(g_l2_losses_abs) / loss_mask_sum
        metrics['g_l2_loss_rel'] = sum(g_l2_losses_rel) / loss_mask_sum
        metrics['ade'] = sum(disp_error) / (total_traj * PRED_LEN)
        metrics['fde'] = sum(f_disp_error) / total_traj
        metrics['msae'] = sum(mean_speed_disp_error) / (total_traj * PRED_LEN)
        metrics['fse'] = sum(final_speed_disp_error) / total_traj
    finally:
        # Restore training mode even if evaluation failed.
        generator.train()
    return metrics
Esempio n. 4
0
def eval(epoch, data_loader=test_loader):
    """Reload saved weights, run one non-training pass and report errors.

    Loads `best_model_file` when `args.mode == 'eval'` and `model_file`
    otherwise. It then runs the epoch iterator matching `args.encoder` on
    `data_loader`. Finally it reads the prediction/target/mask arrays from
    `save_folder` and prints displacement errors to stdout and to `log`.

    Args:
        epoch: epoch index, forwarded to the iterator and printed.
        data_loader: loader to evaluate on (defaults to `test_loader`).

    Returns:
        None.
    """
    t = time.time()
    model.eval()
    checkpoint = best_model_file if args.mode == 'eval' else model_file
    model.load_state_dict(torch.load(checkpoint))

    def load_result(name):
        # Arrays are read from save_folder (presumably written by the epoch iterator).
        return np.load(os.path.join(save_folder, name + '.npy'))

    def report(*fields):
        # Emit one identical line to stdout and to the log file.
        print(*fields)
        print(*fields, file=log)

    def error_fields(errors):
        # [(ade, fde)] in V, H, M order -> ['V_Dis_Error: ade/fde', ...]
        return ['{}_Dis_Error: {:.04f}/{:.04f}'.format(group, ade, fde)
                for group, (ade, fde) in zip('VHM', errors)]

    if args.encoder == 'nmp':
        nmp_iter_one_epoch(data_loader, epoch=epoch, is_training=False)

        # Unmasked errors on each group's own predicted/target centres.
        group_errors = []
        for group in 'vhm':
            pred = load_result('test_{}_pred_centers'.format(group))
            target = load_result('test_{}_target_centers'.format(group))
            group_errors.append((displacement_error(pred, target),
                                 final_displacement_error(pred, target)))
        report('Epoch: {:04d}'.format(epoch), *error_fields(group_errors))

        # Errors on the shared centres, masked by each group's moving mask.
        pred_centers = load_result('test_pred_centers')
        target_centers = load_result('test_target_centers')
        masked_errors = []
        for group in 'vhm':
            mask = load_result('test_{}_moving_masks'.format(group))
            masked_errors.append(
                (displacement_error(pred_centers, target_centers, mask),
                 final_displacement_error(pred_centers, target_centers, mask)))
        report('Epoch: {:04d}'.format(epoch),
               *error_fields(masked_errors),
               'time: {:.04f}'.format(time.time() - t))
    else:
        loss_eval = iter_one_epoch(data_loader, epoch=epoch, is_training=False)
        avg_displacement_error = displacement_error(
            load_result('test_pred_centers'),
            load_result('test_target_centers'),
            load_result('test_moving_masks'))
        report('Epoch: {:04d}'.format(epoch),
               'Loss_Eval: {:.04f}'.format(loss_eval),
               'Displace_Error: {:.04f}'.format(avg_displacement_error),
               'time: {:.04f}'.format(time.time() - t))
Esempio n. 5
0
def train(epoch, min_val_loss):
    """Run one training epoch, evaluate on val/test, report and checkpoint.

    Each report line is printed to stdout and to `log`. The model is always
    saved to `model_file`. It is also saved to `best_model_file` when
    `args.save_folder` is set and this epoch's validation loss is below
    `min_val_loss`.

    Args:
        epoch: epoch index, forwarded to the epoch iterator and printed.
        min_val_loss: best validation loss seen so far.

    Returns:
        This epoch's validation loss.
    """
    t = time.time()
    iter_fn = nmp_iter_one_epoch if args.encoder == 'nmp' else iter_one_epoch

    model.train()
    loss_train = iter_fn(train_loader, epoch, is_training=True)
    # Since PyTorch 1.1 the LR scheduler must step AFTER the epoch's optimizer
    # steps; stepping first skips the schedule's initial learning rate.
    # (Assumes iter_fn(..., is_training=True) performs the optimizer steps.)
    scheduler.step()

    model.eval()
    loss_val = iter_fn(val_loader, epoch, is_training=False)
    model.eval()
    loss_test = iter_fn(test_loader, epoch, is_training=False)

    def load_result(name):
        # Arrays are read from save_folder (presumably written by the epoch iterator).
        return np.load(os.path.join(save_folder, name + '.npy'))

    def report(*fields):
        # Emit one identical line to stdout and to the log file.
        print(*fields)
        print(*fields, file=log)

    def error_fields(errors):
        # [(ade, fde)] in V, H, M order -> ['V_Dis_Error: ade/fde', ...]
        return ['{}_Dis_Error: {:.04f}/{:.04f}'.format(group, ade, fde)
                for group, (ade, fde) in zip('VHM', errors)]

    if args.encoder == 'nmp':
        # Unmasked errors on each group's own predicted/target centres.
        group_errors = []
        for group in 'vhm':
            pred = load_result('test_{}_pred_centers'.format(group))
            target = load_result('test_{}_target_centers'.format(group))
            group_errors.append((displacement_error(pred, target),
                                 final_displacement_error(pred, target)))
        report('Epoch: {:04d}'.format(epoch), *error_fields(group_errors))

        # Errors on the shared centres, masked by each group's moving mask.
        pred_centers = load_result('test_pred_centers')
        target_centers = load_result('test_target_centers')
        masked_errors = []
        for group in 'vhm':
            mask = load_result('test_{}_moving_masks'.format(group))
            masked_errors.append(
                (displacement_error(pred_centers, target_centers, mask),
                 final_displacement_error(pred_centers, target_centers, mask)))
        report('Epoch: {:04d}'.format(epoch),
               'Loss_Train: {:.04f}'.format(loss_train),
               'Loss_Val: {:.04f}'.format(loss_val),
               'Loss_Test: {:.04f}'.format(loss_test),
               *error_fields(masked_errors),
               'time: {:.04f}'.format(time.time() - t))
    else:
        avg_displacement_error = displacement_error(
            load_result('test_pred_centers'),
            load_result('test_target_centers'),
            load_result('test_moving_masks'))
        report('Epoch: {:04d}'.format(epoch),
               'Loss_Eval: {:.04f}'.format(loss_test),
               'Displace_Error: {:.04f}'.format(avg_displacement_error),
               'time: {:.04f}'.format(time.time() - t))
    log.flush()
    torch.save(model.state_dict(), model_file)

    if args.save_folder and loss_val < min_val_loss:
        torch.save(model.state_dict(), best_model_file)
    return loss_val
Esempio n. 6
0
def cal_ade_fde(pred_traj_gt, pred_traj_fake):
    """Return (ADE, FDE) between a predicted and a ground-truth trajectory.

    ADE is `displacement_error` over the whole trajectories. FDE is
    `final_displacement_error` on their last elements (index -1).
    """
    final_fake, final_gt = pred_traj_fake[-1], pred_traj_gt[-1]
    return (displacement_error(pred_traj_fake, pred_traj_gt),
            final_displacement_error(final_fake, final_gt))
Esempio n. 7
0
def check_accuracy(loader, generator, discriminator, d_loss_fn, limit=False):
    """Compute discriminator loss, generator L2 losses and ADE/FDE over `loader`.

    The generator is put in eval mode for the pass. It is restored to train
    mode afterwards, even if evaluation raises.

    Args:
        loader: iterable of 11-tensor batches (see the unpacking below); every
            tensor is moved to the GPU.
        generator: called as generator(obs_traj, obs_traj_rel, V_obs, A_obs, vgg_list).
        discriminator: called as discriminator(traj, traj_rel).
        d_loss_fn: callable (scores_real, scores_fake) -> loss tensor.
        limit: when True, stop once at least NUM_SAMPLES_CHECK trajectories
            have been processed; otherwise use the whole loader.

    Returns:
        dict with keys 'd_loss', 'g_l2_loss_abs', 'g_l2_loss_rel', 'ade', 'fde'.
    """
    d_losses = []
    metrics = {}
    # Two distinct lists. `([],) * 2` repeats ONE list object, so both names
    # aliased it and abs/rel losses were summed together into both metrics.
    g_l2_losses_abs, g_l2_losses_rel = [], []
    disp_error = []    # per-batch ADE values, normalised at the end
    f_disp_error = []  # per-batch FDE values, normalised at the end
    total_traj = 0
    mask_sum = 0
    generator.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                batch = [tensor.cuda() for tensor in batch]
                (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, n_l, l_m,
                 V_obs, A_obs, V_pre, A_pre, vgg_list) = batch
                # Original shape notes: the prediction is "T V C", the start
                # position "V C", and pred_traj_gt "NVCT".
                # NOTE(review): obs_traj / pred_traj_gt are indexed [0, :, :, -1]
                # here, but obs_traj is indexed [:, :, 0, :] in the
                # concatenations below. Confirm the tensor layouts.
                pred_traj_fake_rel = generator(obs_traj, obs_traj_rel, V_obs,
                                               A_obs, vgg_list)
                pred_traj_fake = relative_to_abs(pred_traj_fake_rel,
                                                 obs_traj[0, :, :, -1])

                g_l2_loss_abs = l2_loss(pred_traj_fake, pred_traj_gt, mode='sum')
                g_l2_loss_rel = l2_loss(pred_traj_fake_rel, pred_traj_gt_rel,
                                        mode='sum')

                ade = displacement_error(pred_traj_fake, pred_traj_gt)
                fde = final_displacement_error(pred_traj_fake[-1],
                                               pred_traj_gt[0, :, :, -1])

                traj_real = torch.cat([obs_traj[:, :, 0, :], pred_traj_gt], dim=0)
                traj_real_rel = torch.cat(
                    [obs_traj_rel[:, :, 0, :], pred_traj_gt_rel], dim=0)
                traj_fake = torch.cat([obs_traj[:, :, 0, :], pred_traj_fake], dim=0)
                traj_fake_rel = torch.cat(
                    [obs_traj_rel[:, :, 0, :], pred_traj_fake_rel], dim=0)

                scores_fake = discriminator(traj_fake, traj_fake_rel)
                scores_real = discriminator(traj_real, traj_real_rel)

                d_loss = d_loss_fn(scores_real, scores_fake)
                d_losses.append(d_loss.item())

                g_l2_losses_abs.append(g_l2_loss_abs.item())
                g_l2_losses_rel.append(g_l2_loss_rel.item())
                disp_error.append(ade.item())
                f_disp_error.append(fde.item())

                mask_sum += pred_traj_gt.size(1) * PRED_LEN
                total_traj += pred_traj_gt.size(1)
                if limit and total_traj >= NUM_SAMPLES_CHECK:
                    break

        metrics['d_loss'] = sum(d_losses) / len(d_losses)
        metrics['g_l2_loss_abs'] = sum(g_l2_losses_abs) / mask_sum
        metrics['g_l2_loss_rel'] = sum(g_l2_losses_rel) / mask_sum
        metrics['ade'] = sum(disp_error) / (total_traj * PRED_LEN)
        metrics['fde'] = sum(f_disp_error) / total_traj
    finally:
        # Restore training mode even if evaluation failed.
        generator.train()
    return metrics
Esempio n. 8
0
def check_accuracy(loader, generator, discriminator, d_loss_fn, speed_regressor):
    """Evaluate generator/discriminator on up to NUM_SAMPLE_CHECK trajectories.

    Batches carry label tensors only when MULTI_CONDITIONAL_MODEL is set. In
    that case the labels are forwarded to the generator and discriminator;
    otherwise `None` is passed. The generator is put in eval mode for the pass
    and restored to train mode afterwards, even if evaluation raises.

    Args:
        loader: iterable of 11-tensor (conditional) or 9-tensor batches.
        generator: returns (pred_traj_fake_rel, final_enc_h).
        discriminator: scores concatenated (observed + predicted) trajectories.
        d_loss_fn: callable (scores_real, scores_fake) -> loss tensor.
        speed_regressor: called as speed_regressor(obs_ped_speed, final_enc_h).

    Returns:
        dict with keys 'd_loss', 'g_l2_loss_abs', 'g_l2_loss_rel', 'ade', 'fde'.
    """
    d_losses = []
    metrics = {}
    # Two distinct lists. `([],) * 2` repeats ONE list object, so both names
    # aliased it and abs/rel losses were summed together into both metrics.
    g_l2_losses_abs, g_l2_losses_rel = [], []
    disp_error, f_disp_error = [], []
    total_traj = 0
    loss_mask_sum = 0
    generator.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                if USE_GPU:
                    batch = [tensor.cuda() for tensor in batch]
                else:
                    batch = list(batch)
                if MULTI_CONDITIONAL_MODEL:
                    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, loss_mask,
                     seq_start_end, obs_ped_speed, pred_ped_speed, obs_label,
                     pred_label, obs_obj_rel_speed) = batch
                else:
                    (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, loss_mask,
                     seq_start_end, obs_ped_speed, pred_ped_speed,
                     obs_obj_rel_speed) = batch
                    # Non-conditional batches carry no labels.
                    obs_label = pred_label = None

                pred_traj_fake_rel, final_enc_h = generator(
                    obs_traj, obs_traj_rel, seq_start_end, obs_ped_speed, pred_ped_speed,
                    pred_traj_gt, TRAIN_METRIC, None, obs_obj_rel_speed,
                    obs_label=obs_label, pred_label=pred_label)

                fake_ped_speed = speed_regressor(obs_ped_speed, final_enc_h)

                pred_traj_fake = relative_to_abs(pred_traj_fake_rel, obs_traj[-1])
                # Drop the first OBS_LEN columns of the mask (presumably the
                # observed steps) so it lines up with the predictions.
                loss_mask = loss_mask[:, OBS_LEN:]

                g_l2_loss_abs, g_l2_loss_rel = cal_l2_losses(
                    pred_traj_gt, pred_traj_gt_rel, pred_traj_fake,
                    pred_traj_fake_rel, loss_mask)

                # NOTE(review): computed but never added to `metrics`.
                abs_speed_los = cal_mae_speed_loss(pred_ped_speed, fake_ped_speed)
                ade = displacement_error(pred_traj_gt, pred_traj_fake)
                # NOTE(review): full trajectories (not [-1]) are passed here;
                # confirm final_displacement_error selects the final step itself.
                fde = final_displacement_error(pred_traj_gt, pred_traj_fake)

                traj_real = torch.cat([obs_traj, pred_traj_gt], dim=0)
                traj_real_rel = torch.cat([obs_traj_rel, pred_traj_gt_rel], dim=0)
                traj_fake = torch.cat([obs_traj, pred_traj_fake], dim=0)
                traj_fake_rel = torch.cat([obs_traj_rel, pred_traj_fake_rel], dim=0)
                ped_speed = torch.cat([obs_ped_speed, pred_ped_speed], dim=0)
                label_info = (torch.cat([obs_label, pred_label], dim=0)
                              if MULTI_CONDITIONAL_MODEL else None)
                scores_fake = discriminator(traj_fake, traj_fake_rel, ped_speed,
                                            label=label_info)
                scores_real = discriminator(traj_real, traj_real_rel, ped_speed,
                                            label=label_info)

                d_loss = d_loss_fn(scores_real, scores_fake)
                d_losses.append(d_loss.item())

                g_l2_losses_abs.append(g_l2_loss_abs.item())
                g_l2_losses_rel.append(g_l2_loss_rel.item())
                disp_error.append(ade.item())
                f_disp_error.append(fde.item())

                loss_mask_sum += torch.numel(loss_mask.data)
                total_traj += pred_traj_gt.size(1)
                if total_traj >= NUM_SAMPLE_CHECK:
                    break

        metrics['d_loss'] = sum(d_losses) / len(d_losses)
        metrics['g_l2_loss_abs'] = sum(g_l2_losses_abs) / loss_mask_sum
        metrics['g_l2_loss_rel'] = sum(g_l2_losses_rel) / loss_mask_sum
        metrics['ade'] = sum(disp_error) / (total_traj * PRED_LEN)
        metrics['fde'] = sum(f_disp_error) / total_traj
    finally:
        # Restore training mode even if evaluation failed.
        generator.train()
    return metrics
def evaluate(loader, generator, num_samples, speed_regressor):
    """Evaluate ``generator`` on every batch of ``loader`` with best-of-N sampling.

    For each batch the generator is sampled ``num_samples`` times. The
    per-sample ADE/FDE errors are reduced per batch by ``evaluate_helper``.
    The trajectory that helper selects under the ADE criterion is kept as the
    batch's simulated output. Afterwards the collision percentage of all
    simulated trajectories is computed and printed. When ``TEST_METRIC == 2``,
    the output speed is also checked with ``verify_speed``.

    Args:
        loader: iterable of batches of tensors. When ``MULTI_CONDITIONAL_MODEL``
            is set, each batch unpacks into 11 tensors, including
            ``obs_label`` and ``pred_label``; otherwise it unpacks into 9.
        generator: trajectory generator. It is called with the observed
            trajectory tensors, a mode (``0`` or ``TEST_METRIC``), an optional
            speed tensor, ``obs_obj_rel_speed`` and label keyword arguments,
            and it returns a 2-tuple.
        num_samples: number of generator samples drawn per batch.
        speed_regressor: used only when ``TEST_METRIC == 1`` to produce the
            speed passed to the generator's second call.

    Returns:
        A tuple ``(ade, fde, collision_percent)``. ``ade`` is the summed
        minimum ADE divided by (total trajectories * ``PRED_LEN``). ``fde`` is
        the summed minimum FDE divided by the total number of trajectories.
        ``collision_percent`` is ``collisionPercentage(...) * 100``.

    Raises:
        ValueError: if ``TEST_METRIC`` is neither 1 nor 2.
    """
    if TEST_METRIC not in (1, 2):
        # Otherwise an unsupported mode only surfaces later as a NameError on
        # the never-assigned prediction tensor.
        raise ValueError(
            'evaluate() supports TEST_METRIC 1 (prediction) or 2 (simulation), '
            'got {!r}'.format(TEST_METRIC))

    ade_outer, fde_outer, simulated_output = [], [], []
    total_traj, sequences, labels = [], [], []
    with torch.no_grad():
        for batch in loader:
            if USE_GPU:
                batch = [tensor.cuda() for tensor in batch]
            if MULTI_CONDITIONAL_MODEL:
                (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
                 loss_mask, seq_start_end, obs_ped_speed, pred_ped_speed,
                 obs_label, pred_label, obs_obj_rel_speed) = batch
                labels.append(torch.cat([obs_label, pred_label], dim=0))
            else:
                (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
                 loss_mask, seq_start_end, obs_ped_speed, pred_ped_speed,
                 obs_obj_rel_speed) = batch
                # The non-conditional model receives no labels.
                obs_label = pred_label = None

            total_traj.append(pred_traj_gt.size(1))
            sequences.append(seq_start_end)

            ade_samples, fde_samples, traj_samples = [], [], []
            for _ in range(num_samples):
                pred_traj_fake_rel = _sample_pred_traj_rel(
                    generator, speed_regressor, obs_traj, obs_traj_rel,
                    seq_start_end, obs_ped_speed, pred_ped_speed,
                    pred_traj_gt, obs_obj_rel_speed, obs_label, pred_label)
                pred_traj_fake = relative_to_abs(pred_traj_fake_rel,
                                                 obs_traj[-1])
                ade_samples.append(
                    displacement_error(pred_traj_fake, pred_traj_gt,
                                       mode='raw'))
                fde_samples.append(
                    final_displacement_error(pred_traj_fake[-1],
                                             pred_traj_gt[-1],
                                             mode='raw'))
                traj_samples.append(pred_traj_fake.unsqueeze(dim=0))

            # Concatenate the candidate trajectories once; both best-of-N
            # selections below use the same candidates.
            candidates = torch.cat(traj_samples, dim=0)
            best_traj, min_ade_error = evaluate_helper(
                torch.stack(ade_samples, dim=1), candidates, seq_start_end)
            _, min_fde_error = evaluate_helper(
                torch.stack(fde_samples, dim=1), candidates, seq_start_end)
            ade_outer.append(min_ade_error)
            fde_outer.append(min_fde_error)
            simulated_output.append(best_traj)

        ade = sum(ade_outer) / (sum(total_traj) * PRED_LEN)
        fde = sum(fde_outer) / (sum(total_traj))
        simulated_traj = torch.cat(simulated_output, dim=1)
        all_sequences = _offset_seq_start_end(sequences)
        colpercent = collisionPercentage(simulated_traj, all_sequences)
        print('Collision Percentage: ', colpercent * 100)

        # The user-defined speed is verified by computing the inverse sigmoid
        # function on the output speed of the model.
        if TEST_METRIC == 2:
            if SINGLE_CONDITIONAL_MODEL:
                verify_speed(simulated_traj, all_sequences, labels=None)
            else:
                verify_speed(simulated_traj, all_sequences,
                             labels=torch.cat(labels, dim=1))

        return ade, fde, colpercent * 100


def _sample_pred_traj_rel(generator, speed_regressor, obs_traj, obs_traj_rel,
                          seq_start_end, obs_ped_speed, pred_ped_speed,
                          pred_traj_gt, obs_obj_rel_speed, obs_label,
                          pred_label):
    """Draw one relative-trajectory prediction from ``generator`` for the current TEST_METRIC."""
    speed = None
    if TEST_METRIC == 1:
        # Prediction environment: first call the generator in mode 0 with no
        # speed. Feed its second output (final_enc_h) to the speed regressor
        # to obtain the speed used in the real call below.
        _, final_enc_h = generator(obs_traj, obs_traj_rel, seq_start_end,
                                   obs_ped_speed, pred_ped_speed,
                                   pred_traj_gt, 0, None, obs_obj_rel_speed,
                                   obs_label=obs_label,
                                   pred_label=pred_label)
        speed = speed_regressor(obs_ped_speed, final_enc_h)
    # Simulation environment (TEST_METRIC == 2) passes speed=None.
    pred_traj_fake_rel, _ = generator(obs_traj, obs_traj_rel, seq_start_end,
                                      obs_ped_speed, pred_ped_speed,
                                      pred_traj_gt, TEST_METRIC, speed,
                                      obs_obj_rel_speed,
                                      obs_label=obs_label,
                                      pred_label=pred_label)
    return pred_traj_fake_rel


def _offset_seq_start_end(sequences):
    """Concatenate per-batch seq_start_end tensors, shifting each by the end index of all earlier batches."""
    shifted = []
    offset = 0
    for seq_start_end in sequences:
        shifted.append(seq_start_end + offset)
        # Accumulate the end index of this batch's last sequence.
        offset = offset + seq_start_end[-1][1]
    return torch.cat(shifted, dim=0)