def __getitem__(self, index):
    """Return one Argoverse sample in centerline (n, t) features plus neighbours.

    When ``self.load_saved`` is set, the sample is read back from the pickle
    cache written by an earlier run with ``self.save`` set. Otherwise it is
    computed from the raw sequence ``self.seq_paths[index]`` and, if
    ``self.save`` is set, written to that cache. Train and validation samples
    use separate cache sub-directories (``train/`` and ``val/``).

    Args:
        index: Position of the sequence in ``self.seq_paths``; also used as the
            cache file name.

    Returns:
        dict with keys ``seq_path``, ``train_traj`` (first
        ``self.train_seq_size`` rows of the map features), ``gt_traj`` (the
        remaining rows), ``neighbours`` (output of
        ``self.convert_neighbour_centerline``), ``helpers``, ``norm``,
        ``ref_t`` and ``social_features``. In ``"validate"`` mode it also
        contains ``gt_unnorm_traj``: the raw agent trajectory rows after
        ``self.train_seq_size``.

    Raises:
        ValueError: If ``self.mode`` is neither ``"train"`` nor ``"validate"``.
    """
    if self.mode not in ("train", "validate"):
        # Previously this only surfaced later as an UnboundLocalError on
        # return_dict, after all the feature computation had been done.
        raise ValueError(
            f"Unsupported mode {self.mode!r}; expected 'train' or 'validate'")

    split_dir = "train" if self.mode == "train" else "val"
    cache_path = (f"/home/scratch/nitinsin/argoverse_social_centerline/"
                  f"{split_dir}/{index}.pkl")

    if self.load_saved:
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at a cache directory produced by this dataset with ``self.save``.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    seq_path = self.seq_paths[index]
    current_loader = self.afl.get(seq_path)
    agent_traj = current_loader.agent_traj
    neighbours_traj = current_loader.neighbour_traj()
    social_features, map_features, map_feature_helpers = compute_features(
        seq_path, self.map_features_utils_instance,
        self.social_features_utils_instance, self.avm, 'train')
    oracle_centerline = map_feature_helpers["ORACLE_CENTERLINE"]

    # Round-trip error: xy rebuilt from the (n, t) features along the oracle
    # centerline, compared against the recorded agent trajectory.
    unnorm_traj = get_xy_from_nt_seq(
        np.expand_dims(map_features, axis=0), [oracle_centerline])
    norm = np.linalg.norm(unnorm_traj - agent_traj)

    # Re-reference column 1 (t) so it is zero at row train_seq_size - 1, the
    # last row of train_traj; neighbours are converted with the same offset.
    ref_t = map_features[self.train_seq_size - 1, 1]
    map_features[:, 1] = map_features[:, 1] - ref_t
    neighbour_centerline_frame = self.convert_neighbour_centerline(
        neighbours_traj, oracle_centerline, ref_t)

    return_dict = {
        'seq_path': seq_path,
        'train_traj': map_features[:self.train_seq_size, :],
        'gt_traj': map_features[self.train_seq_size:, :],
    }
    if self.mode == "validate":
        # Future part of the raw (not centerline-converted) agent trajectory.
        return_dict['gt_unnorm_traj'] = agent_traj[self.train_seq_size:, :]
    return_dict.update({
        'neighbours': neighbour_centerline_frame,
        'helpers': map_feature_helpers,
        'norm': norm,
        'ref_t': ref_t,
        'social_features': social_features,
    })

    if self.save:
        # Write to the same split directory the load path reads from.
        with open(cache_path, 'wb') as f:
            pickle.dump(return_dict, f)
    return return_dict
def inverse_transform(self, trajectory, traj_dict):
    """Convert a trajectory in centerline (n, t) coordinates back to xy.

    Args:
        trajectory: Tensor passed to ``get_xy_from_nt_seq`` as ``nt_seq``;
            moved to the CPU first when ``self.use_cuda`` is set.
        traj_dict: Dict whose ``'centerline'`` entry is passed to
            ``get_xy_from_nt_seq`` as ``centerlines``.

    Returns:
        A float32 ``torch.Tensor`` holding the xy result, moved to the GPU when
        ``self.use_cuda`` is set.
    """
    centerlines = traj_dict['centerline']
    nt_input = trajectory.cpu() if self.use_cuda else trajectory
    xy_tensor = torch.Tensor(
        get_xy_from_nt_seq(nt_seq=nt_input, centerlines=centerlines)
    ).float()
    return xy_tensor.cuda() if self.use_cuda else xy_tensor
def __getitem__(self, index):
    """Return one Argoverse sample shaped according to ``self.mode``.

    Modes:
        ``"train"`` / ``"validate"``: agent trajectory as centerline (n, t)
            features from ``compute_features``. Samples can be read from or
            written to a pickle cache keyed by ``index`` (``self.load_saved`` /
            ``self.save``), with separate ``train/`` and ``val/`` directories.
        ``"validate_multiple"``: raw observed and future agent trajectory,
            the feature helpers, the city name and ``norm`` fixed to 0.0.
        ``"test"``: only the sequence path and the feature helpers.

    Args:
        index: Position of the sequence in ``self.seq_paths``; also used as the
            cache file name.

    Returns:
        dict whose keys depend on the mode (see the individual branches).

    Raises:
        ValueError: If ``self.mode`` is not one of the four supported modes.
            Previously such a mode silently returned ``None``.
    """
    if self.mode in ("train", "validate"):
        split_dir = "train" if self.mode == "train" else "val"
        cache_path = f"/home/scratch/nitinsin/argoverse/{split_dir}/{index}.pkl"

        if self.load_saved:
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this at a cache directory produced by this dataset with self.save.
            with open(cache_path, 'rb') as f:
                return pickle.load(f)

        current_loader = self.afl.get(self.seq_paths[index])
        agent_traj = current_loader.agent_traj
        social_features, map_features, map_feature_helpers = compute_features(
            self.seq_paths[index], self.map_features_utils_instance,
            self.social_features_utils_instance, self.avm, 'train')

        # Round-trip error: xy rebuilt from the (n, t) features along the
        # oracle centerline, compared against the recorded agent trajectory.
        unnorm_traj = get_xy_from_nt_seq(
            np.expand_dims(map_features, axis=0),
            [map_feature_helpers["ORACLE_CENTERLINE"]])
        norm = np.linalg.norm(unnorm_traj - agent_traj)

        # Re-reference column 1 (t) so it is zero at row train_seq_size - 1,
        # the last row of train_traj.
        ref_t = map_features[self.train_seq_size - 1, 1]
        map_features[:, 1] = map_features[:, 1] - ref_t

        return_dict = {
            'seq_path': self.seq_paths[index],
            'train_traj': map_features[:self.train_seq_size, :],
        }
        if self.mode == "train":
            return_dict['gt_traj'] = map_features[self.train_seq_size:, :]
        else:
            # Validation keeps the future part of the raw agent trajectory
            # instead of the centerline-frame ground truth.
            return_dict['gt_unnorm_traj'] = agent_traj[self.train_seq_size:, :]
        return_dict.update({
            'helpers': map_feature_helpers,
            'norm': norm,
            'ref_t': ref_t,
            'social_features': social_features,
        })

        if self.save:
            with open(cache_path, 'wb') as f:
                pickle.dump(return_dict, f)
        return return_dict

    if self.mode == "validate_multiple":
        current_loader = self.afl.get(self.seq_paths[index])
        agent_traj = current_loader.agent_traj
        # Only the helpers are used from compute_features in this mode.
        _, _, map_feature_helpers = compute_features(
            self.seq_paths[index], self.map_features_utils_instance,
            self.social_features_utils_instance, self.avm, 'test')
        return {
            'seq_path': self.seq_paths[index],
            'helpers': map_feature_helpers,
            'train_unnorm_traj': agent_traj[0:self.train_seq_size, :],
            'gt_unnorm_traj': agent_traj[self.train_seq_size:, :],
            'city': current_loader.city,
            'norm': 0.0
        }

    if self.mode == "test":
        _, _, map_feature_helpers = compute_features(
            self.seq_paths[index], self.map_features_utils_instance,
            self.social_features_utils_instance, self.avm, 'test')
        return {
            'seq_path': self.seq_paths[index],
            'helpers': map_feature_helpers
        }

    raise ValueError(
        f"Unsupported mode {self.mode!r}; expected one of 'train', "
        f"'validate', 'validate_multiple', 'test'")
collate_fn=collate_traj_multilane) dataloader_val = DataLoader(dataset_val, batch_size=64, shuffle=False, num_workers=8, collate_fn=collate_traj_multilane) all_correct_seq_path = [] # selected=0 total = 0 for i, traj_dict in enumerate(dataloader_train): gt_traj = traj_dict['gt_traj'] gt_unnorm_traj = traj_dict['gt_unnorm_traj'].numpy() all_centerlines = [ helper["ORACLE_CENTERLINE"] for helper in traj_dict['helpers'] ] pred_unnorm_traj = get_xy_from_nt_seq(gt_traj.numpy(), all_centerlines) norm = np.linalg.norm(pred_unnorm_traj - gt_unnorm_traj, axis=(1, 2)) index = norm < 0.05 seq_paths = traj_dict['seq_path'] all_correct_seq_path.extend( [seq_paths[i] for i in range(len(seq_paths)) if index[i] == True]) total += len(seq_paths) print(f"{len(all_correct_seq_path)}/{total} selected for train", end="\r") print() with open("train.pkl", 'wb') as f: pickle.dump(all_correct_seq_path, f) # for i,traj_dict in enumerate(dataloader_val): # gt_traj=traj_dict['gt_traj'] # gt_unnorm_traj=traj_dict['gt_unnorm_traj'].numpy()