def test(epoch, test_loader, model, writer, beta_vals):
    """Use test data to evaluate likelihood of the model.

    Runs the graph model over ``test_loader`` under ``torch.no_grad()``.
    It accumulates the KLD (weighted by ``beta_vals[epoch]``), NLL and KLD_hm
    losses. When ``args.v_loss`` is set, it also accumulates the best-of-k
    ADE variety loss. Every ``args.print_every`` batches it draws one sample
    for ADE/FDE. At the end it logs everything to ``writer`` and prints a
    summary.

    NOTE(review): this function reads the module-level ``args`` instead of
    taking it as a parameter. This module also contains a later
    ``def test`` that rebinds the name, so this definition is shadowed once
    the module has been imported.

    Args:
        epoch: Epoch index. Selects ``beta_vals[epoch]`` and is the logging
            step.
        test_loader: Sized iterable yielding 7-tuples ``(obs_traj,
            pred_traj_gt, obs_traj_rel, pred_traj_gt_rel, seq_start_end,
            maps, dnames)``.
        model: Module whose forward returns ``(kld_loss, nll_loss, kld_hm,
            h)`` and which provides ``sample(...)``.
        writer: Logger exposing ``add_scalar(tag, value, step)``.
        beta_vals: Sequence of KL weights indexed by epoch.

    Raises:
        ValueError: If ``args.adj_type`` is not 0, 1 or 2.
    """
    mean_kld_loss, mean_nll_loss, mean_ade_loss, mean_kld_hm = 0, 0, 0, 0
    # Three distinct lists. ``([],) * 3`` binds all three names to ONE list,
    # which mixed ade/ade_l/ade_nl together and inflated every metric.
    disp_error, disp_error_l, disp_error_nl = [], [], []
    f_disp_error, f_disp_error_l, f_disp_error_nl = [], [], []
    total_traj = 0
    metrics = {}
    model.eval()
    beta = beta_vals[epoch]
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
             seq_start_end, maps, dnames) = batch
            if args.adj_type == 0:
                adj_out = compute_adjs(args, seq_start_end)
            elif args.adj_type == 1:
                adj_out = compute_adjs_distsim(args, seq_start_end, obs_traj,
                                               pred_traj_gt)
            elif args.adj_type == 2:
                adj_out = compute_adjs_knnsim(args, seq_start_end, obs_traj,
                                              pred_traj_gt)
            else:
                # Previously fell through and failed later on an unbound name.
                raise ValueError(
                    'Unsupported adj_type: {}'.format(args.adj_type))

            kld_loss, nll_loss, kld_hm, h = model(
                obs_traj_rel.cuda(), adj_out.cuda(), seq_start_end.cuda(),
                obs_traj[0], maps[:args.obs_len], epoch)
            mean_kld_loss += beta * kld_loss.item()
            mean_nll_loss += nll_loss.item()
            mean_kld_hm += kld_hm.item()

            if args.v_loss:
                # Draw k samples in one call by tiling h k times along dim 1,
                # split them back apart, and keep the lowest ADE among them.
                h_samples = torch.cat(args.k_vloss * [h], 1)
                pred_traj_rel = model.sample(
                    args.pred_len, seq_start_end.cuda(), True,
                    maps[args.obs_len - 1:], obs_traj[-1], dnames, h_samples)
                pred_traj_rel = torch.stack(
                    torch.chunk(pred_traj_rel, args.k_vloss, dim=1))
                v_losses = []
                for k in range(args.k_vloss):
                    pred_traj_abs = relative_to_abs(pred_traj_rel[k],
                                                    obs_traj[-1])
                    v_losses.append(
                        displacement_error(pred_traj_abs, pred_traj_gt)
                        / obs_traj_rel.size(1))
                ade_min = min(v_losses).cuda()
                mean_ade_loss += ade_min.item()

            if i % args.print_every == 0:
                pred_traj_sampled_rel = model.sample(
                    args.pred_len, seq_start_end.cuda(), False,
                    maps[args.obs_len - 1:], obs_traj[-1], dnames, h).cpu()
                pred_traj_sampled = relative_to_abs(pred_traj_sampled_rel,
                                                    obs_traj[-1])
                ade, ade_l, ade_nl = cal_ade(pred_traj_sampled, pred_traj_gt,
                                             linear_ped=None,
                                             non_linear_ped=None)
                fde, fde_l, fde_nl = cal_fde(pred_traj_sampled, pred_traj_gt,
                                             linear_ped=None,
                                             non_linear_ped=None)
                disp_error.append(ade.item())
                disp_error_l.append(ade_l.item())
                disp_error_nl.append(ade_nl.item())
                f_disp_error.append(fde.item())
                f_disp_error_l.append(fde_l.item())
                f_disp_error_nl.append(fde_nl.item())
                total_traj += pred_traj_gt.size(1)

    # ADE is averaged per trajectory AND per predicted step; FDE per trajectory.
    ade_denom = total_traj * args.pred_len
    metrics['ade'] = sum(disp_error) / ade_denom
    metrics['ade_l'] = sum(disp_error_l) / ade_denom
    metrics['ade_nl'] = sum(disp_error_nl) / ade_denom
    metrics['fde'] = sum(f_disp_error) / total_traj
    metrics['fde_l'] = sum(f_disp_error_l) / total_traj
    metrics['fde_nl'] = sum(f_disp_error_nl) / total_traj
    writer.add_scalar('ade', metrics['ade'], epoch)
    writer.add_scalar('fde', metrics['fde'], epoch)

    num_batches = len(test_loader)
    mean_kld_loss /= num_batches
    mean_nll_loss /= num_batches
    mean_ade_loss /= num_batches
    mean_kld_hm /= num_batches
    writer.add_scalar('test_mean_kld_loss', mean_kld_loss, epoch)
    writer.add_scalar('test_mean_nll_loss', mean_nll_loss, epoch)
    if args.v_loss:
        writer.add_scalar('test_mean_ade_loss', mean_ade_loss, epoch)
    if args.use_hm:
        writer.add_scalar('test_mean_kld_hm', mean_kld_hm, epoch)
    writer.add_scalar(
        'loss_test',
        mean_kld_loss + mean_nll_loss + mean_ade_loss + mean_kld_hm, epoch)
    print(
        '====> Test set loss: KLD Loss = {:.4f}, NLL Loss = {:.4f}, ADE = {:.4f}, KLD_HM = {:.4f} '
        .format(mean_kld_loss, mean_nll_loss, mean_ade_loss, mean_kld_hm))
    print(metrics)
def train(epoch, train_loader, optimizer, model, args, writer, beta_vals):
    """Run one training epoch of the graph model and log losses/metrics.

    For each batch this function:
      * builds the adjacency selected by ``args.adj_type``;
      * runs the forward pass;
      * optimises ``beta * kld_loss + nll_loss + kld_hm``, plus the best-of-k
        ADE variety loss when ``args.v_loss`` is set;
      * clips the gradient norm to ``args.clip`` and steps the optimizer.

    Every ``args.print_every`` batches it prints progress, samples one
    trajectory (without gradients) for ADE/FDE, and logs a plot of the first
    sequence in the batch.

    NOTE(review): this module contains a later ``def train`` that rebinds the
    name, so this definition is shadowed once the module has been imported.
    It also logs under the same ``'ade'``/``'fde'`` tags as ``test``.

    Args:
        epoch: Epoch index. Selects ``beta_vals[epoch]`` and is the logging
            step.
        train_loader: DataLoader yielding 7-tuples ``(obs_traj, pred_traj_gt,
            obs_traj_rel, pred_traj_gt_rel, seq_start_end, maps, dnames)``.
        optimizer: Optimizer over ``model.parameters()``.
        model: Module whose forward returns ``(kld_loss, nll_loss, kld_hm,
            h)`` and which provides ``sample(...)``.
        args: Run configuration (``adj_type``, ``obs_len``, ``pred_len``,
            ``v_loss``, ``k_vloss``, ``clip``, ``print_every``, ``use_hm``).
        writer: Logger exposing ``add_scalar`` and ``add_figure``.
        beta_vals: Sequence of KL weights indexed by epoch.

    Raises:
        ValueError: If ``args.adj_type`` is not 0, 1 or 2.
    """
    train_loss = 0
    mean_kld_loss, mean_nll_loss, mean_ade_loss, mean_kld_hm = 0, 0, 0, 0
    # Three distinct lists. ``([],) * 3`` binds all three names to ONE list,
    # which mixed ade/ade_l/ade_nl together and inflated every metric.
    disp_error, disp_error_l, disp_error_nl = [], [], []
    f_disp_error, f_disp_error_l, f_disp_error_nl = [], [], []
    total_traj = 0
    # Sequences processed before the current batch, for the progress line.
    # ``len(batch)`` is the tuple arity (7), not a batch size. Assumes
    # ``len(train_loader.dataset)`` counts sequences -- TODO confirm.
    num_seqs_seen = 0
    metrics = {}
    model.train()
    beta = beta_vals[epoch]
    for batch_idx, batch in enumerate(train_loader):
        (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
         seq_start_end, maps, dnames) = batch
        if args.adj_type == 0:
            adj_out = compute_adjs(args, seq_start_end)
        elif args.adj_type == 1:
            adj_out = compute_adjs_distsim(args, seq_start_end, obs_traj,
                                           pred_traj_gt)
        elif args.adj_type == 2:
            adj_out = compute_adjs_knnsim(args, seq_start_end, obs_traj,
                                          pred_traj_gt)
        else:
            raise ValueError('Unsupported adj_type: {}'.format(args.adj_type))

        # Forward + backward + optimize
        optimizer.zero_grad()
        kld_loss, nll_loss, kld_hm, h = model(
            obs_traj_rel.cuda(), adj_out.cuda(), seq_start_end.cuda(),
            obs_traj[0], maps[:args.obs_len], epoch)

        if args.v_loss:
            # Draw k samples in one call by tiling h k times along dim 1,
            # split them back apart, and keep the lowest ADE among them.
            h_samples = torch.cat(args.k_vloss * [h], 1)
            pred_traj_rel = model.sample(
                args.pred_len, seq_start_end.cuda(), True,
                maps[args.obs_len - 1:], obs_traj[-1], dnames, h_samples)
            pred_traj_rel = torch.stack(
                torch.chunk(pred_traj_rel, args.k_vloss, dim=1))
            v_losses = []
            for k in range(args.k_vloss):
                pred_traj_abs = relative_to_abs(pred_traj_rel[k], obs_traj[-1])
                v_losses.append(
                    displacement_error(pred_traj_abs, pred_traj_gt)
                    / obs_traj_rel.size(1))
            ade_min = min(v_losses).cuda()
            mean_ade_loss += ade_min.item()
            loss = beta * kld_loss + nll_loss + ade_min + kld_hm
        else:
            loss = beta * kld_loss + nll_loss + kld_hm

        mean_kld_loss += kld_loss.item()
        mean_nll_loss += nll_loss.item()
        mean_kld_hm += kld_hm.item()

        loss.backward()
        # Clipping gradients
        nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        train_loss += loss.item()

        # Printing
        if batch_idx % args.print_every == 0:
            print(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\t KLD Loss: {:.6f} \t NLL Loss: {:.6f} \t KLD_hm: {:.6f}'
                .format(epoch, num_seqs_seen, len(train_loader.dataset),
                        100. * batch_idx / len(train_loader),
                        kld_loss.item(), nll_loss.item(), kld_hm.item()))
            with torch.no_grad():
                pred_traj_sampled_rel = model.sample(
                    args.pred_len, seq_start_end.cuda(), False,
                    maps[args.obs_len - 1:], obs_traj[-1], dnames, h).cpu()
                pred_traj_sampled = relative_to_abs(pred_traj_sampled_rel,
                                                    obs_traj[-1])
                ade, ade_l, ade_nl = cal_ade(pred_traj_sampled, pred_traj_gt,
                                             linear_ped=None,
                                             non_linear_ped=None)
                fde, fde_l, fde_nl = cal_fde(pred_traj_sampled, pred_traj_gt,
                                             linear_ped=None,
                                             non_linear_ped=None)
                disp_error.append(ade.item())
                disp_error_l.append(ade_l.item())
                disp_error_nl.append(ade_nl.item())
                f_disp_error.append(fde.item())
                f_disp_error_l.append(fde_l.item())
                f_disp_error_nl.append(fde_nl.item())
                total_traj += pred_traj_gt.size(1)

                # Plot the first sequence of the batch (rows start:end).
                start, end = seq_start_end[0][0], seq_start_end[0][1]
                input_a = obs_traj[:, start:end, :].detach()
                gt = pred_traj_gt[:, start:end, :].detach()
                out_a = pred_traj_sampled[:, start:end, :].detach()
                # Prepend the last observed step (input_a[-1]) to both the
                # ground-truth and the sampled future tracks.
                last_obs = np.asarray(input_a[-1].unsqueeze(0).cpu())
                gt_r = np.insert(np.asarray(gt.cpu()), 0, last_obs, axis=0)
                out_a_r = np.insert(np.asarray(out_a.cpu()), 0, last_obs,
                                    axis=0)
                img2 = draw_all_trj_seq(np.asarray(input_a.cpu()), gt_r,
                                        out_a_r, args)
                writer.add_figure('Generated_samples_in_absolute_coordinates',
                                  img2, epoch)

        num_seqs_seen += seq_start_end.size(0)

    # ADE is averaged per trajectory AND per predicted step; FDE per trajectory.
    ade_denom = total_traj * args.pred_len
    metrics['ade'] = sum(disp_error) / ade_denom
    metrics['ade_l'] = sum(disp_error_l) / ade_denom
    metrics['ade_nl'] = sum(disp_error_nl) / ade_denom
    metrics['fde'] = sum(f_disp_error) / total_traj
    metrics['fde_l'] = sum(f_disp_error_l) / total_traj
    metrics['fde_nl'] = sum(f_disp_error_nl) / total_traj
    writer.add_scalar('ade', metrics['ade'], epoch)
    writer.add_scalar('fde', metrics['fde'], epoch)

    num_batches = len(train_loader)
    mean_kld_loss /= num_batches
    mean_nll_loss /= num_batches
    mean_ade_loss /= num_batches
    mean_kld_hm /= num_batches
    writer.add_scalar('train_mean_kld_loss', mean_kld_loss, epoch)
    writer.add_scalar('train_mean_nll_loss', mean_nll_loss, epoch)
    if args.v_loss:
        writer.add_scalar('train_mean_ade_loss', mean_ade_loss, epoch)
    if args.use_hm:
        writer.add_scalar('train_mean_kld_hm', mean_kld_hm, epoch)
    writer.add_scalar('loss_train', train_loss / num_batches, epoch)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / num_batches))
    print(metrics)
def train(epoch, train_loader, optimizer, model, args, writer, beta_vals):
    """Run one training epoch of the non-graph model and log losses/metrics.

    For each batch this function optimises ``beta * kld_loss + nll_loss``,
    plus the best-of-k ADE variety loss when ``args.v_loss`` is set. It clips
    the gradient norm to ``args.clip`` and steps the optimizer.

    Every ``args.print_every`` batches it prints progress, samples one
    trajectory (without gradients) for ADE/FDE/L2 metrics, and logs a plot
    of the first sequence in the batch.

    NOTE(review): this function relies on a module-level ``device``.

    Args:
        epoch: Epoch index. Selects ``beta_vals[epoch]`` and is the logging
            step.
        train_loader: DataLoader yielding 8-tuples ``(obs_traj, pred_traj_gt,
            obs_traj_rel, pred_traj_gt_rel, non_linear_ped, loss_mask,
            seq_start_end, maps)``.
        optimizer: Optimizer over ``model.parameters()``.
        model: Module whose forward returns ``(kld_loss, nll_loss, extras,
            h)`` and which provides ``sample(pred_len, num_peds, h)``.
        args: Run configuration (``obs_len``, ``pred_len``, ``v_loss``,
            ``k_vloss``, ``clip``, ``print_every``).
        writer: Logger exposing ``add_scalar`` and ``add_figure``.
        beta_vals: Sequence of KL weights indexed by epoch.
    """
    train_loss = 0
    loss_mask_sum = 0
    # Distinct lists. ``([],) * n`` would alias a single list n times.
    disp_error = []
    f_disp_error = []
    l2_losses_abs, l2_losses_rel = [], []
    total_traj = 0
    # Sequences processed before the current batch, for the progress line.
    # ``len(batch)`` is the tuple arity (8), not a batch size. Assumes
    # ``len(train_loader.dataset)`` counts sequences -- TODO confirm.
    num_seqs_seen = 0
    metrics = {}
    model.train()
    # nn.Module.to() moves parameters in place and returns the module, so
    # calling it once before the loop is equivalent to calling it per batch.
    model = model.to(device)
    beta = beta_vals[epoch]
    for batch_idx, batch in enumerate(train_loader):
        (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
         non_linear_ped, loss_mask, seq_start_end, maps) = batch
        loss_mask = loss_mask[:, args.obs_len:]
        linear_ped = 1 - non_linear_ped

        # Forward + backward + optimize
        optimizer.zero_grad()
        kld_loss, nll_loss, _, h = model(obs_traj_rel.cuda(), obs_traj[0])
        if args.v_loss:
            v_losses = []
            for _ in range(args.k_vloss):
                pred_traj_rel = model.sample(args.pred_len,
                                             obs_traj_rel.size(1), h)
                pred_traj_abs = relative_to_abs(pred_traj_rel, obs_traj[-1])
                v_losses.append(
                    displacement_error(pred_traj_abs, pred_traj_gt)
                    / obs_traj_rel.size(1))
            loss = beta * kld_loss + nll_loss + min(v_losses)
        else:
            loss = beta * kld_loss + nll_loss
        loss.backward()
        # Clipping gradients
        nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        optimizer.step()
        train_loss += loss.item()

        # Printing
        if batch_idx % args.print_every == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\t KLD Loss: {:.6f} \t NLL Loss: {:.6f}'.format(
                epoch, num_seqs_seen, len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                kld_loss.item(), nll_loss.item()))
            # Metric sampling needs no autograd graph.
            with torch.no_grad():
                pred_traj_sampled_rel = model.sample(
                    args.pred_len, obs_traj_rel.size(1), h)
                pred_traj_sampled = relative_to_abs(pred_traj_sampled_rel,
                                                    obs_traj[-1])
                ade, _, _ = cal_ade(pred_traj_sampled, pred_traj_gt,
                                    linear_ped, non_linear_ped)
                fde, _, _ = cal_fde(pred_traj_sampled, pred_traj_gt,
                                    linear_ped, non_linear_ped)
                # Argument order (gt, gt_rel, pred_abs, pred_rel, mask)
                # matches the other cal_l2_losses call sites in this module.
                # This call previously passed the sampled rel/abs swapped.
                l2_loss_abs, l2_loss_rel = cal_l2_losses(
                    pred_traj_gt, pred_traj_gt_rel, pred_traj_sampled,
                    pred_traj_sampled_rel, loss_mask)
                l2_losses_abs.append(l2_loss_abs.item())
                l2_losses_rel.append(l2_loss_rel.item())
                disp_error.append(ade.item())
                f_disp_error.append(fde.item())
                loss_mask_sum += loss_mask.numel()
                total_traj += pred_traj_gt.size(1)

                # Plot the first sequence of the batch (rows start:end).
                start, end = seq_start_end[0][0], seq_start_end[0][1]
                input_a = obs_traj[:, start:end, :].detach()
                gt = pred_traj_gt[:, start:end, :].detach()
                out_a = pred_traj_sampled[:, start:end, :].detach()
                # Prepend the last observed step (input_a[-1]) to both the
                # ground-truth and the sampled future tracks.
                last_obs = np.asarray(input_a[-1].unsqueeze(0).cpu())
                gt_r = np.insert(np.asarray(gt.cpu()), 0, last_obs, axis=0)
                out_a_r = np.insert(np.asarray(out_a.cpu()), 0, last_obs,
                                    axis=0)
                img2 = draw_all_trj_seq(np.asarray(input_a.cpu()), gt_r,
                                        out_a_r, args)
                writer.add_figure('Generated_samples_in_absolute_coordinates',
                                  img2, epoch)

        num_seqs_seen += seq_start_end.size(0)

    metrics['l2_loss_abs'] = sum(l2_losses_abs) / loss_mask_sum
    metrics['l2_loss_rel'] = sum(l2_losses_rel) / loss_mask_sum
    metrics['ade'] = sum(disp_error) / (total_traj * args.pred_len)
    metrics['fde'] = sum(f_disp_error) / total_traj
    writer.add_scalar('ade', metrics['ade'], epoch)
    writer.add_scalar('fde', metrics['fde'], epoch)
    writer.add_scalar('loss_train', train_loss / len(train_loader.dataset),
                      epoch)
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))
    print(metrics)
def test(epoch, test_loader, model, writer, beta_vals):
    """Use test data to evaluate likelihood of the model.

    Runs the non-graph model over ``test_loader`` under ``torch.no_grad()``.
    It accumulates the beta-weighted KLD and the NLL; when ``args.v_loss`` is
    set, the best-of-k ADE is added into the NLL term. Every
    ``args.print_every`` batches it samples one trajectory for ADE/FDE/L2
    metrics. At the end it logs to ``writer`` and prints a summary.

    NOTE(review): this function reads the module-level ``args`` and
    ``device`` instead of taking them as parameters.

    Args:
        epoch: Epoch index. Selects ``beta_vals[epoch]`` and is the logging
            step.
        test_loader: DataLoader yielding 8-tuples ``(obs_traj, pred_traj_gt,
            obs_traj_rel, pred_traj_gt_rel, non_linear_ped, loss_mask,
            seq_start_end, maps)``.
        model: Module whose forward returns ``(kld_loss, nll_loss, extras,
            h)`` and which provides ``sample(pred_len, num_peds, h)``.
        writer: Logger exposing ``add_scalar(tag, value, step)``.
        beta_vals: Sequence of KL weights indexed by epoch.
    """
    mean_kld_loss, mean_nll_loss = 0, 0
    loss_mask_sum = 0
    # Distinct lists. ``([],) * n`` would alias a single list n times.
    disp_error = []
    f_disp_error = []
    l2_losses_abs, l2_losses_rel = [], []
    total_traj = 0
    metrics = {}
    model.eval()
    # nn.Module.to() moves parameters in place, so once is enough.
    model = model.to(device)
    beta = beta_vals[epoch]
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
             non_linear_ped, loss_mask, seq_start_end, maps) = batch
            loss_mask = loss_mask[:, args.obs_len:]
            linear_ped = 1 - non_linear_ped

            kld_loss, nll_loss, _, h = model(obs_traj_rel.cuda(), obs_traj[0])
            mean_kld_loss += beta * kld_loss.item()
            if args.v_loss:
                v_losses = []
                for _ in range(args.k_vloss):
                    pred_traj_rel = model.sample(args.pred_len,
                                                 obs_traj_rel.size(1), h)
                    pred_traj_abs = relative_to_abs(pred_traj_rel,
                                                    obs_traj[-1])
                    v_losses.append(
                        displacement_error(pred_traj_abs, pred_traj_gt)
                        / obs_traj_rel.size(1))
                ade_min = min(v_losses)
                mean_nll_loss += nll_loss.item() + ade_min.item()
            else:
                mean_nll_loss += nll_loss.item()

            if i % args.print_every == 0:
                pred_traj_sampled_rel = model.sample(
                    args.pred_len, obs_traj_rel.size(1), h)
                pred_traj_sampled = relative_to_abs(pred_traj_sampled_rel,
                                                    obs_traj[-1])
                ade, _, _ = cal_ade(pred_traj_sampled, pred_traj_gt,
                                    linear_ped, non_linear_ped)
                fde, _, _ = cal_fde(pred_traj_sampled, pred_traj_gt,
                                    linear_ped, non_linear_ped)
                # Argument order (gt, gt_rel, pred_abs, pred_rel, mask)
                # matches the other cal_l2_losses call sites in this module.
                # This call previously passed the sampled rel/abs swapped.
                l2_loss_abs, l2_loss_rel = cal_l2_losses(
                    pred_traj_gt, pred_traj_gt_rel, pred_traj_sampled,
                    pred_traj_sampled_rel, loss_mask)
                l2_losses_abs.append(l2_loss_abs.item())
                l2_losses_rel.append(l2_loss_rel.item())
                disp_error.append(ade.item())
                f_disp_error.append(fde.item())
                loss_mask_sum += loss_mask.numel()
                total_traj += pred_traj_gt.size(1)

    metrics['l2_loss_abs'] = sum(l2_losses_abs) / loss_mask_sum
    metrics['l2_loss_rel'] = sum(l2_losses_rel) / loss_mask_sum
    metrics['ade'] = sum(disp_error) / (total_traj * args.pred_len)
    metrics['fde'] = sum(f_disp_error) / total_traj
    writer.add_scalar('ade', metrics['ade'], epoch)
    writer.add_scalar('fde', metrics['fde'], epoch)
    mean_kld_loss /= len(test_loader.dataset)
    mean_nll_loss /= len(test_loader.dataset)
    writer.add_scalar('loss_test', mean_kld_loss + mean_nll_loss, epoch)
    print('====> Test set loss: KLD Loss = {:.4f}, NLL Loss = {:.4f} '.format(
        mean_kld_loss, mean_nll_loss))
    print(metrics)
def check_accuracy_graph_sways(args, loader, model, epoch, limit=False):
    """Evaluate the graph model on ``loader`` without linear/non-linear splits.

    Puts the model in eval mode and runs it under ``torch.no_grad()``. Train
    mode is restored afterwards, even if evaluation raises.

    Args:
        args: Run configuration (``adj_type``, ``obs_len``, ``pred_len``,
            ``num_samples_check``).
        loader: Iterable yielding 7-tuples ``(obs_traj, pred_traj_gt,
            obs_traj_rel, pred_traj_gt_rel, seq_start_end, maps, dnames)``.
        model: Module whose forward returns ``(kld_loss, nll_loss, kld_hm,
            h)`` and which provides ``sample(...)``.
        epoch: Forwarded to the model's forward call.
        limit: If true, stop once at least ``args.num_samples_check``
            trajectories have been evaluated.

    Returns:
        ``(metrics, mean_loss)``. ``metrics`` holds ``loss``, ``ade``, ``fde``
        and zeroed ``ade_l``/``fde_l``/``ade_nl``/``fde_nl``. ``mean_loss``
        is the average of ``kld_loss + nll_loss + kld_hm`` over the batches
        actually evaluated.

    Raises:
        ValueError: If ``args.adj_type`` is unsupported or ``loader`` yields
            no batches.
    """
    losses = []
    metrics = {}
    disp_error = []
    f_disp_error = []
    total_traj = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
                 seq_start_end, maps, dnames) = batch
                if args.adj_type == 0:
                    adj_out = compute_adjs(args, seq_start_end)
                elif args.adj_type == 1:
                    adj_out = compute_adjs_distsim(args, seq_start_end,
                                                   obs_traj, pred_traj_gt)
                elif args.adj_type == 2:
                    adj_out = compute_adjs_knnsim(args, seq_start_end,
                                                  obs_traj, pred_traj_gt)
                else:
                    raise ValueError(
                        'Unsupported adj_type: {}'.format(args.adj_type))

                kld_loss, nll_loss, kld_hm, h = model(
                    obs_traj_rel.cuda(), adj_out.cuda(), seq_start_end.cuda(),
                    obs_traj[0], maps[:args.obs_len], epoch)
                loss = kld_loss + nll_loss + kld_hm

                pred_traj_rel = model.sample(
                    args.pred_len, seq_start_end.cuda(), False,
                    maps[args.obs_len - 1:], obs_traj[-1], dnames, h).cpu()
                pred_traj = relative_to_abs(pred_traj_rel, obs_traj[-1])
                ade, _, _ = cal_ade(pred_traj_gt, pred_traj, linear_ped=None,
                                    non_linear_ped=None)
                fde, _, _ = cal_fde(pred_traj_gt, pred_traj, linear_ped=None,
                                    non_linear_ped=None)

                losses.append(loss.item())
                disp_error.append(ade.item())
                f_disp_error.append(fde.item())
                total_traj += pred_traj_gt.size(1)
                if limit and total_traj >= args.num_samples_check:
                    break
    finally:
        model.train()

    if not losses:
        raise ValueError('check_accuracy_graph_sways: loader yielded no batches')

    # Average over batches actually evaluated. With ``limit`` the loop may
    # stop early, so ``len(loader)`` would understate the mean.
    mean_loss = sum(losses) / len(losses)
    metrics['loss'] = mean_loss
    metrics['ade'] = sum(disp_error) / (total_traj * args.pred_len)
    metrics['fde'] = sum(f_disp_error) / total_traj
    metrics['ade_l'] = 0
    metrics['fde_l'] = 0
    metrics['ade_nl'] = 0
    metrics['fde_nl'] = 0
    return metrics, mean_loss
def check_accuracy_graph(args, loader, model, epoch, limit=False):
    """Evaluate the graph model on ``loader`` with linear/non-linear splits.

    Puts the model in eval mode and runs it under ``torch.no_grad()``. Train
    mode is restored afterwards, even if evaluation raises.

    Args:
        args: Run configuration (``adj_type``, ``obs_len``, ``pred_len``,
            ``num_samples_check``).
        loader: Iterable yielding 9-tuples ``(obs_traj, pred_traj_gt,
            obs_traj_rel, pred_traj_gt_rel, non_linear_ped, loss_mask,
            seq_start_end, maps, dnames)``.
        model: Module whose forward returns ``(kld_loss, nll_loss, kld_hm,
            h)`` and which provides ``sample(...)``.
        epoch: Forwarded to the model's forward call.
        limit: If true, stop once at least ``args.num_samples_check``
            trajectories have been evaluated.

    Returns:
        ``(metrics, mean_loss)``. ``metrics`` holds ``loss``, ``l2_loss_abs``,
        ``l2_loss_rel``, ``ade``, ``fde``, ``ade_l``, ``fde_l``, ``ade_nl``
        and ``fde_nl``. ``mean_loss`` is the average of
        ``kld_loss + nll_loss + kld_hm`` over the batches actually evaluated.

    Raises:
        ValueError: If ``args.adj_type`` is unsupported or ``loader`` yields
            no batches.
    """
    losses = []
    metrics = {}
    # Distinct lists. ``([],) * n`` binds every name to ONE list, which
    # previously summed ade + ade_l + ade_nl into every metric.
    l2_losses_abs, l2_losses_rel = [], []
    disp_error, disp_error_l, disp_error_nl = [], [], []
    f_disp_error, f_disp_error_l, f_disp_error_nl = [], [], []
    total_traj, total_traj_l, total_traj_nl = 0, 0, 0
    loss_mask_sum = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
                 non_linear_ped, loss_mask, seq_start_end, maps,
                 dnames) = batch
                linear_ped = 1 - non_linear_ped
                loss_mask = loss_mask[:, args.obs_len:]
                if args.adj_type == 0:
                    adj_out = compute_adjs(args, seq_start_end)
                elif args.adj_type == 1:
                    adj_out = compute_adjs_distsim(args, seq_start_end,
                                                   obs_traj, pred_traj_gt)
                elif args.adj_type == 2:
                    adj_out = compute_adjs_knnsim(args, seq_start_end,
                                                  obs_traj, pred_traj_gt)
                else:
                    raise ValueError(
                        'Unsupported adj_type: {}'.format(args.adj_type))

                kld_loss, nll_loss, kld_hm, h = model(
                    obs_traj_rel.cuda(), adj_out.cuda(), seq_start_end.cuda(),
                    obs_traj[0], maps[:args.obs_len], epoch)
                loss = kld_loss + nll_loss + kld_hm

                pred_traj_rel = model.sample(
                    args.pred_len, seq_start_end.cuda(), False,
                    maps[args.obs_len - 1:], obs_traj[-1], dnames, h).cpu()
                pred_traj = relative_to_abs(pred_traj_rel, obs_traj[-1])
                l2_loss_abs, l2_loss_rel = cal_l2_losses(
                    pred_traj_gt, pred_traj_gt_rel, pred_traj, pred_traj_rel,
                    loss_mask)
                ade, ade_l, ade_nl = cal_ade(pred_traj_gt, pred_traj,
                                             linear_ped, non_linear_ped)
                fde, fde_l, fde_nl = cal_fde(pred_traj_gt, pred_traj,
                                             linear_ped, non_linear_ped)

                losses.append(loss.item())
                l2_losses_abs.append(l2_loss_abs.item())
                l2_losses_rel.append(l2_loss_rel.item())
                disp_error.append(ade.item())
                disp_error_l.append(ade_l.item())
                disp_error_nl.append(ade_nl.item())
                f_disp_error.append(fde.item())
                f_disp_error_l.append(fde_l.item())
                f_disp_error_nl.append(fde_nl.item())

                loss_mask_sum += loss_mask.numel()
                total_traj += pred_traj_gt.size(1)
                total_traj_l += torch.sum(linear_ped).item()
                total_traj_nl += torch.sum(non_linear_ped).item()
                if limit and total_traj >= args.num_samples_check:
                    break
    finally:
        model.train()

    if not losses:
        raise ValueError('check_accuracy_graph: loader yielded no batches')

    # Average over batches actually evaluated. With ``limit`` the loop may
    # stop early, so ``len(loader)`` would understate the mean.
    mean_loss = sum(losses) / len(losses)
    metrics['loss'] = mean_loss
    metrics['l2_loss_abs'] = sum(l2_losses_abs) / loss_mask_sum
    metrics['l2_loss_rel'] = sum(l2_losses_rel) / loss_mask_sum
    metrics['ade'] = sum(disp_error) / (total_traj * args.pred_len)
    metrics['fde'] = sum(f_disp_error) / total_traj
    if total_traj_l != 0:
        metrics['ade_l'] = sum(disp_error_l) / (total_traj_l * args.pred_len)
        metrics['fde_l'] = sum(f_disp_error_l) / total_traj_l
    else:
        metrics['ade_l'] = 0
        metrics['fde_l'] = 0
    if total_traj_nl != 0:
        metrics['ade_nl'] = sum(disp_error_nl) / (
            total_traj_nl * args.pred_len)
        metrics['fde_nl'] = sum(f_disp_error_nl) / total_traj_nl
    else:
        metrics['ade_nl'] = 0
        metrics['fde_nl'] = 0
    return metrics, mean_loss
def check_accuracy_baseline(args, loader, model, limit=False):
    """Evaluate a baseline ('vrnn' or 'rnn') model on ``loader``.

    Puts the model in eval mode and runs it under ``torch.no_grad()``. Train
    mode is restored afterwards, even if evaluation raises.

    Args:
        args: Run configuration (``model``, ``obs_len``, ``pred_len``,
            ``num_samples_check``).
        loader: Iterable yielding 9-tuples ``(obs_traj, pred_traj_gt,
            obs_traj_rel, pred_traj_gt_rel, non_linear_ped, loss_mask,
            seq_start_end, maps, dnames)``.
        model: For ``'vrnn'``, forward returns ``(kld_loss, nll_loss, extras,
            h)``. For ``'rnn'``, forward returns ``(loss, extras, h)``. Both
            provide ``sample(...)``.
        limit: If true, stop once at least ``args.num_samples_check``
            trajectories have been evaluated.

    Returns:
        ``(metrics, mean_loss)``. ``mean_loss`` is the average loss over the
        batches actually evaluated.

    Raises:
        ValueError: If ``args.model`` is neither ``'vrnn'`` nor ``'rnn'``, or
            ``loader`` yields no batches.
    """
    losses = []
    metrics = {}
    # Distinct lists. ``([],) * n`` binds every name to ONE list, which
    # previously summed ade + ade_l + ade_nl into every metric.
    l2_losses_abs, l2_losses_rel = [], []
    disp_error, disp_error_l, disp_error_nl = [], [], []
    f_disp_error, f_disp_error_l, f_disp_error_nl = [], [], []
    total_traj, total_traj_l, total_traj_nl = 0, 0, 0
    loss_mask_sum = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                (obs_traj, pred_traj_gt, obs_traj_rel, pred_traj_gt_rel,
                 non_linear_ped, loss_mask, seq_start_end, maps,
                 dnames) = batch
                linear_ped = 1 - non_linear_ped
                loss_mask = loss_mask[:, args.obs_len:]

                if args.model == 'vrnn':
                    kld_loss, nll_loss, _, h = model(obs_traj_rel.cuda(),
                                                     obs_traj[0])
                    loss = kld_loss + nll_loss
                elif args.model == 'rnn':
                    loss, _, h = model(obs_traj_rel.cuda())
                else:
                    # Previously fell through to an unbound ``loss``/``h``.
                    raise ValueError(
                        'Unsupported model type: {}'.format(args.model))

                # NOTE(review): unlike the graph variant, no ``.cpu()`` here;
                # presumably sample() already returns a CPU tensor -- confirm.
                pred_traj_rel = model.sample(args.pred_len,
                                             obs_traj_rel.size(1),
                                             obs_traj[-1], dnames, h)
                pred_traj = relative_to_abs(pred_traj_rel, obs_traj[-1])
                l2_loss_abs, l2_loss_rel = cal_l2_losses(
                    pred_traj_gt, pred_traj_gt_rel, pred_traj, pred_traj_rel,
                    loss_mask)
                ade, ade_l, ade_nl = cal_ade(pred_traj_gt, pred_traj,
                                             linear_ped, non_linear_ped)
                fde, fde_l, fde_nl = cal_fde(pred_traj_gt, pred_traj,
                                             linear_ped, non_linear_ped)

                losses.append(loss.item())
                l2_losses_abs.append(l2_loss_abs.item())
                l2_losses_rel.append(l2_loss_rel.item())
                disp_error.append(ade.item())
                disp_error_l.append(ade_l.item())
                disp_error_nl.append(ade_nl.item())
                f_disp_error.append(fde.item())
                f_disp_error_l.append(fde_l.item())
                f_disp_error_nl.append(fde_nl.item())

                loss_mask_sum += loss_mask.numel()
                total_traj += pred_traj_gt.size(1)
                total_traj_l += torch.sum(linear_ped).item()
                total_traj_nl += torch.sum(non_linear_ped).item()
                if limit and total_traj >= args.num_samples_check:
                    break
    finally:
        model.train()

    if not losses:
        raise ValueError('check_accuracy_baseline: loader yielded no batches')

    # Average over batches actually evaluated. With ``limit`` the loop may
    # stop early, so ``len(loader)`` would understate the mean.
    mean_loss = sum(losses) / len(losses)
    metrics['loss'] = mean_loss
    metrics['l2_loss_abs'] = sum(l2_losses_abs) / loss_mask_sum
    metrics['l2_loss_rel'] = sum(l2_losses_rel) / loss_mask_sum
    metrics['ade'] = sum(disp_error) / (total_traj * args.pred_len)
    metrics['fde'] = sum(f_disp_error) / total_traj
    if total_traj_l != 0:
        metrics['ade_l'] = sum(disp_error_l) / (total_traj_l * args.pred_len)
        metrics['fde_l'] = sum(f_disp_error_l) / total_traj_l
    else:
        metrics['ade_l'] = 0
        metrics['fde_l'] = 0
    if total_traj_nl != 0:
        metrics['ade_nl'] = sum(disp_error_nl) / (
            total_traj_nl * args.pred_len)
        metrics['fde_nl'] = sum(f_disp_error_nl) / total_traj_nl
    else:
        metrics['ade_nl'] = 0
        metrics['fde_nl'] = 0
    return metrics, mean_loss