def test(args, epch, test_loader, model, warmup, writer):
    """Evaluate ``model`` on ``test_loader`` and log losses and ADE/FDE.

    Runs one pass over the test set with gradients disabled. It accumulates
    the KLD and NLL terms returned by ``model`` and the displacement errors of
    trajectories produced by ``model.sample`` from the final hidden state.
    Scalars are written to ``writer`` unless the module-level ``DEBUG`` flag is
    set; in DEBUG mode the last batch is plotted with ``plot_traj`` instead.

    Args:
        args: run configuration; ``print_every_batch`` and ``pred_len`` are read.
        epch: current epoch. ``warmup[epch - 1]`` is read, so it is presumably
            1-based. It is also used as the logging step.
        test_loader: iterable of 7-tuples of tensors (see the unpacking below);
            ``len(test_loader.dataset)`` is used to normalise the losses.
        model: callable returning ``(kld, nll, h)`` and exposing ``sample``.
        warmup: indexable per-epoch weights applied to the KLD term.
        writer: object exposing ``add_scalar(tag, value, step)``.

    Raises:
        ValueError: if ``test_loader`` produced no trajectories, so no
            metric can be computed.
    """
    test_loss = 0
    kld_loss = 0
    nll_loss = 0
    total_traj = 0
    ade = 0
    fde = 0
    # The KLD annealing weight is constant over the epoch. Index with the
    # ``epch`` parameter; the previous code referenced an undefined ``epoch``.
    kld_weight = warmup[epch - 1]

    model.eval()
    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            data = [tensor.cuda() for tensor in data]
            (obs_traj, pred_traj_gt, obs_traj_rel, _pred_traj_gt_rel,
             _non_linear_ped, _loss_mask, seq_start_end) = data

            kld, nll, h = model(obs_traj_rel)
            loss = (kld_weight * kld) + nll
            test_loss += loss.item()
            kld_loss += kld.item()
            nll_loss += nll.item()

            if batch_idx % args.print_every_batch == 0 and not DEBUG:
                writer.add_scalar('test/loss_items', loss.item(), epch)

            # Predict trajectories from the latest hidden state h.
            samples_rel = model.sample(args.pred_len, h)
            samples = relative_to_abs(samples_rel, obs_traj[-1])
            # Dim 1 is counted as the number of trajectories in the batch
            # (presumably agents); verify against relative_to_abs.
            total_traj += samples.shape[1]
            ade += average_displacement_error(samples, pred_traj_gt)
            fde += final_displacement_error(samples[-1, :, :], pred_traj_gt[-1, :, :])

    if total_traj == 0:
        # Otherwise the divisions below (and the plot of "the last batch") fail
        # with an opaque ZeroDivisionError/NameError.
        raise ValueError("test_loader produced no trajectories; cannot compute test metrics")

    n_samples = len(test_loader.dataset)
    # Average the loss accumulated over all batches, not the last batch's loss.
    mean_loss = test_loss / n_samples
    mean_kld_loss = kld_loss / n_samples
    mean_nll_loss = nll_loss / n_samples
    # ADE is averaged over every predicted timestep; FDE only over final positions.
    ade_val = ade / (total_traj * args.pred_len)
    fde_val = fde / total_traj

    if DEBUG:
        # Plot the last batch seen in the loop.
        obs = obs_traj.cpu().numpy()
        pred = samples.cpu().numpy()
        pred_gt = pred_traj_gt.cpu().numpy()
        plot_traj(obs, pred_gt, pred, seq_start_end, writer, epch)
    else:
        writer.add_scalar('test/Loss', mean_loss, epch)
        writer.add_scalar('test/KLD_loss', mean_kld_loss, epch)
        writer.add_scalar('test/NLL_loss', mean_nll_loss, epch)
        writer.add_scalar('test/ADE', ade_val, epch)
        writer.add_scalar('test/FDE', fde_val, epch)
def _batch_adjacency(args, seq_start_end, obs_traj, pred_traj_gt):
    """Build one batch's adjacency matrix per ``args.adjacency_type`` and move it to the GPU.

    The mapping is 0 -> ``compute_adjs``, 1 -> ``compute_adjs_distsim`` and
    2 -> ``compute_adjs_knnsim``. The latter two receive detached CPU copies
    of the observed and ground-truth future trajectories.

    Raises:
        ValueError: if ``args.adjacency_type`` is not 0, 1 or 2. Previously
            such a value left the adjacency unbound and failed later with a
            NameError.
    """
    adjacency_type = args.adjacency_type
    if adjacency_type == 0:
        adj = compute_adjs(args, seq_start_end)
    elif adjacency_type == 1:
        # NOTE(review): the ground-truth future is passed in at evaluation time;
        # confirm this helper does not leak it into the model's input.
        adj = compute_adjs_distsim(args, seq_start_end,
                                   obs_traj.detach().cpu(), pred_traj_gt.detach().cpu())
    elif adjacency_type == 2:
        # NOTE(review): same ground-truth-future concern as above.
        adj = compute_adjs_knnsim(args, seq_start_end,
                                  obs_traj.detach().cpu(), pred_traj_gt.detach().cpu())
    else:
        raise ValueError(f"unsupported adjacency_type: {adjacency_type!r} (expected 0, 1 or 2)")
    return adj.cuda()


def test(args, epch, test_loader, model, warmup, writer):
    """Evaluate the goal-conditioned graph model on ``test_loader`` and log metrics.

    Runs one pass over the test set with gradients disabled. For each batch it:

    - one-hot encodes the observed goals;
    - builds the adjacency matrix;
    - accumulates the NLL, weighted KLD and weighted cross-entropy terms;
    - scores trajectories from ``model.sample`` with ADE/FDE.

    Unless the module-level ``DEBUG`` flag is set, the averages are written to
    ``writer`` and the last batch is plotted with ``plot_traj``.

    Args:
        args: run configuration; ``g_dim``, ``adjacency_type``, ``CE_weight``,
            ``print_every_batch`` and ``pred_len`` are read.
        epch: current epoch. ``warmup[epch - 1]`` is read, so it is presumably
            1-based. It is also used as the logging step.
        test_loader: iterable of 7-tuples of tensors (see the unpacking below);
            ``len(test_loader.dataset)`` is used to normalise the losses.
        model: callable returning ``(kld, nll, ce, h)`` and exposing ``sample``.
        warmup: indexable per-epoch weights applied to the KLD term.
        writer: object exposing ``add_scalar(tag, value, step)``.

    Raises:
        ValueError: if ``args.adjacency_type`` is unsupported, or if the loader
            produced no trajectories.
    """
    test_loss = 0
    kld_loss = 0
    nll_loss = 0
    cross_entropy_loss = 0
    total_traj = 0
    ade = 0
    fde = 0
    # The KLD annealing weight is constant over the epoch. Index with the
    # ``epch`` parameter; the previous code referenced an undefined ``epoch``.
    kld_weight = warmup[epch - 1]

    model.eval()
    with torch.no_grad():
        for batch_idx, data in enumerate(test_loader):
            data = [tensor.cuda() for tensor in data]
            (obs_traj, pred_traj_gt, obs_traj_rel, _pred_traj_rel_gt,
             obs_goals, _pred_goals_gt, seq_start_end) = data

            obs_goals_ohe = to_goals_one_hot(obs_goals, args.g_dim).cuda()
            adj_out = _batch_adjacency(args, seq_start_end, obs_traj, pred_traj_gt)

            kld, nll, ce, h = model(obs_traj, obs_traj_rel, obs_goals_ohe, seq_start_end, adj_out)
            loss = nll + (kld_weight * kld) + (ce * args.CE_weight)
            test_loss += loss.item()
            kld_loss += kld.item()
            nll_loss += nll.item()
            cross_entropy_loss += ce.item()

            if batch_idx % args.print_every_batch == 0 and not DEBUG:
                writer.add_scalar('test/loss_items', loss.item(), epch)

            # Predict from the latest h. Per the original comment, samples_rel
            # has shape (pred_seq_len, n_agents, batch, xy); TODO confirm.
            samples_rel = model.sample(args.pred_len, h, obs_traj[-1], obs_goals_ohe[-1], seq_start_end)
            samples = relative_to_abs(samples_rel, obs_traj[-1])
            total_traj += samples.shape[1]
            ade += average_displacement_error(samples, pred_traj_gt)
            fde += final_displacement_error(samples[-1, :, :], pred_traj_gt[-1, :, :])

    if total_traj == 0:
        # Otherwise the divisions below (and the plot of "the last batch") fail
        # with an opaque ZeroDivisionError/NameError.
        raise ValueError("test_loader produced no trajectories; cannot compute test metrics")

    n_samples = len(test_loader.dataset)
    # Average the loss accumulated over all batches, not the last batch's loss.
    mean_loss = test_loss / n_samples
    mean_kld_loss = kld_loss / n_samples
    mean_nll_loss = nll_loss / n_samples
    mean_cross_entropy_loss = cross_entropy_loss / n_samples
    # ADE is averaged over every predicted timestep; FDE only over final positions.
    ade_val = ade / (total_traj * args.pred_len)
    fde_val = fde / total_traj

    if not DEBUG:
        writer.add_scalar('test/Loss', mean_loss, epch)
        writer.add_scalar('test/KLD_loss', mean_kld_loss, epch)
        writer.add_scalar('test/NLL_loss', mean_nll_loss, epch)
        writer.add_scalar('test/CE_loss', mean_cross_entropy_loss, epch)
        writer.add_scalar('test/ADE', ade_val, epch)
        writer.add_scalar('test/FDE', fde_val, epch)

        # Plot the last batch seen in the loop.
        obs = obs_traj.cpu().numpy()
        pred = samples.cpu().numpy()
        pred_gt = pred_traj_gt.cpu().numpy()
        plot_traj(obs, pred_gt, pred, seq_start_end, writer, epch)