# Forward pass if params.use_maneuvers: if epoch_num < pretrainEpochs: # During pre-training with MSE loss, validate with MSE for true maneuver class trajectory net.train_flag = True fut_pred, _, _ = net(hist, nbrs, mask, lat_enc, lon_enc, hist_grid) l = maskedMSE(fut_pred, fut, op_mask) else: # During training with NLL loss, validate with NLL over multi-modal distribution fut_pred, lat_pred, lon_pred = net(hist, nbrs, mask, lat_enc, lon_enc, hist_grid) l = maskedNLLTest(fut_pred, lat_pred, lon_pred, fut, op_mask, avg_along_time=True) avg_val_lat_acc += (torch.sum( torch.max(lat_pred.data, 1)[1] == torch.max( lat_enc.data, 1)[1])).item() / lat_enc.size()[0] avg_val_lon_acc += (torch.sum( torch.max(lon_pred.data, 1)[1] == torch.max( lon_enc.data, 1)[1])).item() / lon_enc.size()[0] else: fut_pred = net(hist, nbrs, mask, lat_enc, lon_enc, hist_grid) if epoch_num < pretrainEpochs: l = maskedMSE(fut_pred, fut, op_mask) else: l = maskedNLL(fut_pred, fut, op_mask)
def train_model():
    """Train the PiP network: MSE pre-training followed by NLL training.

    Command-line arguments are read from the module-level ``parser``. The model is
    trained for ``pretrain_epochs`` epochs with masked MSE, then ``train_epochs``
    epochs with masked NLL plus BCE maneuver-classification losses. Every 100
    batches a short in-process validation pass is run; checkpoints are written to
    ``./trained_models/<name>/`` after every epoch and once more at the end.
    """
    args = parser.parse_args()
    print("------------- {} -------------".format(args.name))
    print("Batch size : {}".format(args.batch_size))
    print("Learning rate : {}".format(args.learning_rate))
    print("Use Planning Coupled: {}".format(args.use_planning))
    print("Use Target Fusion: {}".format(args.use_fusion))

    ## Initialize network and optimizer
    PiP = pipNet(args)
    if args.use_cuda:
        PiP = PiP.cuda()
    optimizer = torch.optim.Adam(PiP.parameters(), lr=args.learning_rate)
    crossEnt = torch.nn.BCELoss()

    ## Initialize the log folder
    log_path = "./trained_models/{}/".format(args.name)
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    if args.tensorboard:
        logger = SummaryWriter(log_path + 'train-pre{}-nll{}'.format(args.pretrain_epochs, args.train_epochs))
        logger_val = SummaryWriter(log_path + 'validation-pre{}-nll{}'.format(args.pretrain_epochs, args.train_epochs))

    ## Initialize training parameters
    pretrainEpochs = args.pretrain_epochs
    trainEpochs = args.train_epochs
    batch_size = args.batch_size

    ## Initialize data loaders
    # Training set fits no planned trajectory; validation set does (fit_plan_traj=True).
    print("Train dataset: {}".format(args.train_set))
    trSet = highwayTrajDataset(path=args.train_set,
                               targ_enc_size=args.social_context_size + args.dynamics_encoding_size,
                               grid_size=args.grid_size,
                               fit_plan_traj=False)
    print("Validation dataset: {}".format(args.val_set))
    valSet = highwayTrajDataset(path=args.val_set,
                                targ_enc_size=args.social_context_size + args.dynamics_encoding_size,
                                grid_size=args.grid_size,
                                fit_plan_traj=True)
    trDataloader = DataLoader(trSet, batch_size=batch_size, shuffle=True,
                              num_workers=args.num_workers, collate_fn=trSet.collate_fn)
    valDataloader = DataLoader(valSet, batch_size=batch_size, shuffle=True,
                               num_workers=args.num_workers, collate_fn=valSet.collate_fn)
    print("DataSet Prepared : {} train data, {} validation data\n".format(len(trSet), len(valSet)))
    print("Network structure: {}\n".format(PiP))

    ## Training process
    for epoch_num in range(pretrainEpochs + trainEpochs):
        if epoch_num == 0:
            print('Pretrain with MSE loss')
        elif epoch_num == pretrainEpochs:
            print('Train with NLL loss')

        ## Variables to track training performance:
        avg_time_tr, avg_loss_tr, avg_loss_val = 0, 0, 0

        ## Training status, reclaim after each epoch
        PiP.train()
        PiP.train_output_flag = True
        for i, data in enumerate(trDataloader):
            st_time = time.time()
            nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask, targsFut, targsFutMask, lat_enc, lon_enc, _ = data
            if args.use_cuda:
                nbsHist = nbsHist.cuda()
                nbsMask = nbsMask.cuda()
                planFut = planFut.cuda()
                planMask = planMask.cuda()
                targsHist = targsHist.cuda()
                targsEncMask = targsEncMask.cuda()
                lat_enc = lat_enc.cuda()
                lon_enc = lon_enc.cuda()
                targsFut = targsFut.cuda()
                targsFutMask = targsFutMask.cuda()

            # Forward pass
            fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask,
                                               targsHist, targsEncMask, lat_enc, lon_enc)
            if epoch_num < pretrainEpochs:
                # Pre-train with MSE loss to speed up training
                l = maskedMSE(fut_pred, targsFut, targsFutMask)
            else:
                # Train with NLL loss, plus BCE on the lateral/longitudinal maneuver heads.
                l = maskedNLL(fut_pred, targsFut, targsFutMask) \
                    + crossEnt(lat_pred, lat_enc) + crossEnt(lon_pred, lon_enc)

            # Back-prop and update weights; gradients are clipped to max-norm 10.
            optimizer.zero_grad()
            l.backward()
            torch.nn.utils.clip_grad_norm_(PiP.parameters(), 10)
            optimizer.step()

            # Track average train loss and average train time:
            batch_time = time.time() - st_time
            avg_loss_tr += l.item()
            avg_time_tr += batch_time

            # For every 100 batches: record loss, validate model, and plot.
            if i % 100 == 99:
                eta = avg_time_tr / 100 * (len(trSet) / batch_size - i)
                epoch_progress = i * batch_size / len(trSet)
                print("Epoch no:", epoch_num + 1,
                      "| Epoch progress(%):", format(epoch_progress * 100, '0.2f'),
                      "| Avg train loss:", format(avg_loss_tr / 100, '0.2f'),
                      "| ETA(s):", int(eta))
                if args.tensorboard:
                    logger.add_scalar("RMSE" if epoch_num < pretrainEpochs else "NLL",
                                      avg_loss_tr / 100, (epoch_progress + epoch_num) * 100)

                ## Validation during training:
                eval_batch_num = 20
                with torch.no_grad():
                    PiP.eval()
                    PiP.train_output_flag = False
                    # NOTE: use a dedicated index here — the original shadowed the
                    # training-loop variable `i`, which was confusing and fragile.
                    for val_i, data in enumerate(valDataloader):
                        nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask, targsFut, targsFutMask, lat_enc, lon_enc, _ = data
                        if args.use_cuda:
                            nbsHist = nbsHist.cuda()
                            nbsMask = nbsMask.cuda()
                            planFut = planFut.cuda()
                            planMask = planMask.cuda()
                            targsHist = targsHist.cuda()
                            targsEncMask = targsEncMask.cuda()
                            lat_enc = lat_enc.cuda()
                            lon_enc = lon_enc.cuda()
                            targsFut = targsFut.cuda()
                            targsFutMask = targsFutMask.cuda()
                        if epoch_num < pretrainEpochs:
                            # During pre-training with MSE loss, validate with MSE for true maneuver class trajectory
                            PiP.train_output_flag = True
                            fut_pred, _, _ = PiP(nbsHist, nbsMask, planFut, planMask,
                                                 targsHist, targsEncMask, lat_enc, lon_enc)
                            l = maskedMSE(fut_pred, targsFut, targsFutMask)
                        else:
                            # During training with NLL loss, validate with NLL over multi-modal distribution
                            fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask,
                                                               targsHist, targsEncMask, lat_enc, lon_enc)
                            l = maskedNLLTest(fut_pred, lat_pred, lon_pred, targsFut, targsFutMask,
                                              avg_along_time=True)
                        avg_loss_val += l.item()
                        if val_i == (eval_batch_num - 1):
                            if args.tensorboard:
                                logger_val.add_scalar("RMSE" if epoch_num < pretrainEpochs else "NLL",
                                                      avg_loss_val / eval_batch_num,
                                                      (epoch_progress + epoch_num) * 100)
                            break

                # Clear statistic
                avg_time_tr, avg_loss_tr, avg_loss_val = 0, 0, 0
                # Revert to train mode after in-process evaluation.
                PiP.train()
                PiP.train_output_flag = True

        ## Save the model after each epoch______________________________________________________________________________
        epoCount = epoch_num + 1
        if epoCount < pretrainEpochs:
            torch.save(PiP.state_dict(), log_path + "{}-pre{}-nll{}.tar".format(args.name, epoCount, 0))
        else:
            torch.save(PiP.state_dict(),
                       log_path + "{}-pre{}-nll{}.tar".format(args.name, pretrainEpochs, epoCount - pretrainEpochs))

    # All epochs finish________________________________________________________________________________________________
    torch.save(PiP.state_dict(), log_path + "{}.tar".format(args.name))
    print("Model saved in trained_models/{}/{}.tar\n".format(args.name, args.name))
# NOTE(review): evaluation-loop fragment from a different script variant (dict-style
# `args['use_cuda']`, graph inputs, `mode='path'`). `tsDataloader`, `net`, `metric`,
# `lossVals`, `counts` and the masked* loss helpers come from the enclosing scope.
for i, data in enumerate(tsDataloader):
    st_time = time.time()
    hist, nbrs, mask, lat_enc, lon_enc, fut, op_mask, graph_list, _, _, nbrs_idx = data
    # Initialize Variables
    # (lat_enc/lon_enc and graph_list are not moved to the device here — presumably
    # unused by this model variant or handled internally; verify against `net`.)
    if args['use_cuda']:
        hist = hist.to(args['device'])
        nbrs = nbrs.to(args['device'])
        mask = mask.to(args['device'])
        fut = fut.to(args['device'])
        op_mask = op_mask.to(args['device'])
        nbrs_idx = nbrs_idx.to(args['device'])
    if metric == 'nll':
        # Forward pass
        fut_pred = net(hist, nbrs, mask, graph_list, nbrs_idx, mode='path')
        # Maneuver-free NLL: the 0, 0 placeholders stand in for lat/lon predictions.
        l, c = maskedNLLTest(fut_pred, 0, 0, fut, op_mask, use_maneuvers=False)
    else:
        # Forward pass
        fut_pred = net(hist, nbrs, mask, graph_list, nbrs_idx, mode='path')
        l, c = maskedMSETest(fut_pred, fut, op_mask)
    # Accumulate per-timestep loss sums and valid-sample counts across batches.
    lossVals += l.detach()
    counts += c.detach()
# Report the count-normalized loss over the whole test set.
if metric == 'nll':
    print(lossVals / counts)
else:
    print(torch.pow(lossVals / counts, 0.5) * 0.3048)  # Calculate RMSE and convert from feet to meters
def model_evaluate():
    """Evaluate a trained PiP model on the test set with NLL and RMSE metrics.

    Loads ``./trained_models/<prefix>/<name>.tar`` (prefix = part of the name before
    the first '-'), runs inference over the test set, and logs RMSE (meters) and NLL
    (nats) at 1-second intervals. ``args.metric`` selects the aggregation:
    'agent' averages losses per (dataset, vehicle) pair; 'sample' averages over all
    prediction samples. Raises RuntimeError for any other metric value.
    """
    args = parser.parse_args()
    ## Initialize network
    PiP = pipNet(args)
    PiP.load_state_dict(torch.load('./trained_models/{}/{}.tar'.format((args.name).split('-')[0], args.name)))
    if args.use_cuda:
        PiP = PiP.cuda()
    ## Evaluation Mode
    # train_output_flag=False: multi-modal test-time output rather than the
    # teacher-forced single-maneuver output used during training.
    PiP.eval()
    PiP.train_output_flag = False
    initLogging(log_file='./trained_models/{}/evaluation.log'.format((args.name).split('-')[0]))
    ## Intialize dataset
    logging.info("Loading test data from {}...".format(args.test_set))
    tsSet = highwayTrajDataset(path=args.test_set,
                               targ_enc_size=args.social_context_size + args.dynamics_encoding_size,
                               grid_size=args.grid_size,
                               fit_plan_traj=True,
                               fit_plan_further_ds=args.plan_info_ds)
    logging.info("TOTAL :: {} test data.".format(len(tsSet)))
    tsDataloader = DataLoader(tsSet, batch_size=args.batch_size, shuffle=False,
                              num_workers=args.num_workers, collate_fn=tsSet.collate_fn)
    ## Loss statistic
    logging.info("<{}> evaluated by {}-based NLL & RMSE, with planning input of {}s step."
                 .format(args.name, args.metric, args.plan_info_ds * 0.2))
    if args.metric == 'agent':
        # 3-D accumulators indexed by [dataset id, vehicle id, output timestep].
        # Data[:, 0] presumably holds dataset ids and columns 13:13+grid_cells the
        # vehicle ids in the grid — TODO confirm against highwayTrajDataset.
        nll_loss_stat = np.zeros((np.max(tsSet.Data[:, 0]).astype(int) + 1,
                                  np.max(tsSet.Data[:, 13:(13 + tsSet.grid_cells)]).astype(int) + 1,
                                  args.out_length))
        rmse_loss_stat = np.zeros((np.max(tsSet.Data[:, 0]).astype(int) + 1,
                                   np.max(tsSet.Data[:, 13:(13 + tsSet.grid_cells)]).astype(int) + 1,
                                   args.out_length))
        both_count_stat = np.zeros((np.max(tsSet.Data[:, 0]).astype(int) + 1,
                                    np.max(tsSet.Data[:, 13:(13 + tsSet.grid_cells)]).astype(int) + 1,
                                    args.out_length))
    elif args.metric == 'sample':
        # Per-timestep accumulators over the 25-frame (5 s) prediction horizon.
        # NOTE(review): these are created on GPU unconditionally, ignoring
        # args.use_cuda — this branch requires CUDA; confirm intended.
        rmse_loss = torch.zeros(25).cuda()
        rmse_counts = torch.zeros(25).cuda()
        nll_loss = torch.zeros(25).cuda()
        nll_counts = torch.zeros(25).cuda()
    else:
        raise RuntimeError("Wrong type of evaluation metric is specified")
    avg_eva_time = 0
    ## Evaluation process
    with torch.no_grad():
        for i, data in enumerate(tsDataloader):
            st_time = time.time()
            nbsHist, nbsMask, planFut, planMask, targsHist, targsEncMask, targsFut, targsFutMask, lat_enc, lon_enc, idxs = data
            # Initialize Variables
            if args.use_cuda:
                nbsHist = nbsHist.cuda()
                nbsMask = nbsMask.cuda()
                planFut = planFut.cuda()
                planMask = planMask.cuda()
                targsHist = targsHist.cuda()
                targsEncMask = targsEncMask.cuda()
                lat_enc = lat_enc.cuda()
                lon_enc = lon_enc.cuda()
                targsFut = targsFut.cuda()
                targsFutMask = targsFutMask.cuda()
            # Inference
            fut_pred, lat_pred, lon_pred = PiP(nbsHist, nbsMask, planFut, planMask,
                                               targsHist, targsEncMask, lat_enc, lon_enc)
            # Performance metric
            if args.metric == 'agent':
                dsIDs, targsIDs = tsSet.batchTargetVehsInfo(idxs)
                l, c = maskedNLLTest(fut_pred, lat_pred, lon_pred, targsFut, targsFutMask, separately=True)
                # Select the trajectory with the largest probability of maneuver label when evaluating by RMSE
                fut_pred_max = torch.zeros_like(fut_pred[0])
                for k in range(lat_pred.shape[0]):
                    lat_man = torch.argmax(lat_pred[k, :]).detach()
                    lon_man = torch.argmax(lon_pred[k, :]).detach()
                    # 3 lateral maneuver classes per longitudinal class → flat index.
                    indx = lon_man * 3 + lat_man
                    fut_pred_max[:, k, :] = fut_pred[indx][:, k, :]
                # Using the most probable trajectory
                ll, cc = maskedMSETest(fut_pred_max, targsFut, targsFutMask, separately=True)
                l = l.detach().cpu().numpy()
                ll = ll.detach().cpu().numpy()
                c = c.detach().cpu().numpy()
                cc = cc.detach().cpu().numpy()
                # Scatter per-sample losses into the per-(dataset, vehicle) accumulators.
                for j, targ in enumerate(targsIDs):
                    dsID = dsIDs[j]
                    nll_loss_stat[dsID, targ, :] += l[:, j]
                    rmse_loss_stat[dsID, targ, :] += ll[:, j]
                    both_count_stat[dsID, targ, :] += c[:, j]
            elif args.metric == 'sample':
                l, c = maskedNLLTest(fut_pred, lat_pred, lon_pred, targsFut, targsFutMask)
                nll_loss += l.detach()
                nll_counts += c.detach()
                # Same most-probable-maneuver selection as the 'agent' branch.
                fut_pred_max = torch.zeros_like(fut_pred[0])
                for k in range(lat_pred.shape[0]):
                    lat_man = torch.argmax(lat_pred[k, :]).detach()
                    lon_man = torch.argmax(lon_pred[k, :]).detach()
                    indx = lon_man * 3 + lat_man
                    fut_pred_max[:, k, :] = fut_pred[indx][:, k, :]
                l, c = maskedMSETest(fut_pred_max, targsFut, targsFutMask)
                rmse_loss += l.detach()
                rmse_counts += c.detach()
            # Time estimate
            batch_time = time.time() - st_time
            avg_eva_time += batch_time
            if i % 100 == 99:
                eta = avg_eva_time / 100 * (len(tsSet) / args.batch_size - i)
                logging.info("Evaluation progress(%):{:.2f}".format(
                    i / (len(tsSet) / args.batch_size) * 100, ) + " | ETA(s):{}".format(int(eta)))
                avg_eva_time = 0
    # Result Summary
    if args.metric == 'agent':
        # Loss averaged from all predicted vehicles.
        ds_ids, veh_ids = both_count_stat[:, :, 0].nonzero()
        num_vehs = len(veh_ids)
        rmse_loss_averaged = np.zeros((args.out_length, num_vehs))
        nll_loss_averaged = np.zeros((args.out_length, num_vehs))
        count_averaged = np.zeros((args.out_length, num_vehs))
        for i in range(num_vehs):
            # Boolean mask of timesteps with at least one observation; the 1e-9
            # guards against division by zero where the count is 0.
            count_averaged[:, i] = \
                both_count_stat[ds_ids[i], veh_ids[i], :].astype(bool)
            rmse_loss_averaged[:,i] = rmse_loss_stat[ds_ids[i], veh_ids[i], :] \
                * count_averaged[:, i] / (both_count_stat[ds_ids[i], veh_ids[i], :] + 1e-9)
            nll_loss_averaged[:,i] = nll_loss_stat[ds_ids[i], veh_ids[i], :] \
                * count_averaged[:, i] / (both_count_stat[ds_ids[i], veh_ids[i], :] + 1e-9)
        rmse_loss_sum = np.sum(rmse_loss_averaged, axis=1)
        nll_loss_sum = np.sum(nll_loss_averaged, axis=1)
        count_sum = np.sum(count_averaged, axis=1)
        rmseOverall = np.power(rmse_loss_sum / count_sum, 0.5) * 0.3048  # Unit converted from feet to meter.
        nllOverall = nll_loss_sum / count_sum
    elif args.metric == 'sample':
        rmseOverall = (torch.pow(rmse_loss / rmse_counts, 0.5) * 0.3048).cpu()
        nllOverall = (nll_loss / nll_counts).cpu()
    # Print the metrics every 5 time frame (1s)
    logging.info("RMSE (m)\t=> {}, Mean={:.3f}".format(
        rmseOverall[4::5], rmseOverall[4::5].mean()))
    logging.info("NLL (nats)\t=> {}, Mean={:.3f}".format(
        nllOverall[4::5], nllOverall[4::5].mean()))