Exemplo n.º 1
0
    def __getitem__(self, index):
        """Return the feature dictionary for the sequence at ``index``.

        Supported modes are ``"train"`` and ``"validate"``.  When
        ``self.load_saved`` is set, a previously pickled dictionary is read
        from the on-disk cache instead of recomputing features.  When
        ``self.save`` is set, the freshly computed dictionary is written to the
        same cache location in either mode, so the ``load_saved`` path can
        find it.

        Args:
            index: Position in ``self.seq_paths``; also used as the cache file
                name.

        Returns:
            dict with keys ``seq_path``, ``train_traj``, ``gt_traj``,
            ``neighbours``, ``helpers``, ``norm``, ``ref_t`` and
            ``social_features``.  In ``"validate"`` mode it also contains
            ``gt_unnorm_traj`` (the future part of the raw agent trajectory).

        Raises:
            ValueError: if ``self.mode`` is neither ``"train"`` nor
                ``"validate"``.  The old code did all the feature work and
                then failed with ``UnboundLocalError``.
        """
        cache_subdirs = {"train": "train", "validate": "val"}
        if self.mode not in cache_subdirs:
            raise ValueError(
                f"Unsupported mode {self.mode!r}; expected 'train' or 'validate'")
        cache_path = ("/home/scratch/nitinsin/argoverse_social_centerline/"
                      f"{cache_subdirs[self.mode]}/{index}.pkl")

        if self.load_saved:
            # NOTE(review): pickle.load can execute arbitrary code; only point
            # this at caches written by this pipeline.
            with open(cache_path, 'rb') as f:
                return pickle.load(f)

        seq_path = self.seq_paths[index]
        current_loader = self.afl.get(seq_path)
        agent_traj = current_loader.agent_traj
        neighbours_traj = current_loader.neighbour_traj()
        social_features, map_features, map_feature_helpers = compute_features(
            seq_path, self.map_features_utils_instance,
            self.social_features_utils_instance, self.avm, 'train')

        # Map the features back to x/y along the oracle centerline.  Record the
        # L2 distance from the raw agent trajectory (presumably a round-trip
        # error check; confirm).
        unnorm_traj = get_xy_from_nt_seq(np.expand_dims(
            map_features, axis=0), [map_feature_helpers["ORACLE_CENTERLINE"]])
        norm = np.linalg.norm(unnorm_traj - agent_traj)
        # Re-reference column 1 so that it is 0 at the last observed step.
        ref_t = map_features[self.train_seq_size - 1, 1]
        map_features[:, 1] = map_features[:, 1] - ref_t
        # Presumably expresses neighbour trajectories relative to the agent's
        # oracle centerline and ref_t (TODO confirm).
        neighbour_centerline_frame = self.convert_neighbour_centerline(
            neighbours_traj, map_feature_helpers["ORACLE_CENTERLINE"], ref_t)

        # Keys are inserted in the same order as before.
        return_dict = {
            'seq_path': seq_path,
            'train_traj': map_features[:self.train_seq_size, :],
            'gt_traj': map_features[self.train_seq_size:, :],
        }
        if self.mode == "validate":
            return_dict['gt_unnorm_traj'] = agent_traj[self.train_seq_size:, :]
        return_dict.update(
            neighbours=neighbour_centerline_frame,
            helpers=map_feature_helpers,
            norm=norm,
            ref_t=ref_t,
            social_features=social_features,
        )

        if self.save:
            with open(cache_path, 'wb') as f:
                pickle.dump(return_dict, f)
        return return_dict
Exemplo n.º 2
0
 def inverse_transform(self, trajectory, traj_dict):
     """Convert ``trajectory`` to x/y using ``traj_dict['centerline']``.

     The conversion is delegated to ``get_xy_from_nt_seq``.  When
     ``self.use_cuda`` is set, the input is moved to the CPU before the
     conversion and the resulting tensor is moved back to the GPU.

     Returns:
         A float tensor built from the ``get_xy_from_nt_seq`` output.
     """
     centerlines = traj_dict['centerline']
     move_to_gpu = self.use_cuda
     nt_input = trajectory.cpu() if move_to_gpu else trajectory
     xy = get_xy_from_nt_seq(nt_seq=nt_input, centerlines=centerlines)
     xy_tensor = torch.Tensor(xy).float()
     return xy_tensor.cuda() if move_to_gpu else xy_tensor
Exemplo n.º 3
0
    def __getitem__(self, index):
        """Return the feature dictionary for the sequence at ``index``.

        The contents depend on ``self.mode``:

        * ``"train"`` / ``"validate"``: if ``self.load_saved`` is set, the
          dictionary is read from the pickle cache.  Otherwise features are
          computed with ``compute_features`` and, if ``self.save`` is set,
          written to that cache.  Column 1 of the map features is shifted by
          ``ref_t``, its value at the last observed step.  ``"train"`` returns
          the future map features as ``gt_traj``.  ``"validate"`` returns the
          future raw agent trajectory as ``gt_unnorm_traj``.
        * ``"validate_multiple"``: the raw observed/future agent trajectory,
          map-feature helpers and city, with ``norm`` fixed to 0.0.
        * ``"test"``: only the sequence path and the map-feature helpers.

        Args:
            index: Position in ``self.seq_paths``; also used as the cache file
                name.

        Raises:
            ValueError: if ``self.mode`` is none of the modes above.  The old
                code silently returned ``None``.
        """
        if self.mode in ("train", "validate"):
            cache_subdir = "train" if self.mode == "train" else "val"
            cache_path = ("/home/scratch/nitinsin/argoverse/"
                          f"{cache_subdir}/{index}.pkl")
            if self.load_saved:
                # NOTE(review): pickle.load can execute arbitrary code; only
                # point this at caches written by this pipeline.
                with open(cache_path, 'rb') as f:
                    return pickle.load(f)

            seq_path = self.seq_paths[index]
            agent_traj = self.afl.get(seq_path).agent_traj
            social_features, map_features, map_feature_helpers = compute_features(
                seq_path, self.map_features_utils_instance,
                self.social_features_utils_instance, self.avm, 'train')
            # Map the features back to x/y along the oracle centerline.
            # Record the L2 distance from the raw agent trajectory (presumably
            # a round-trip error check; confirm).
            unnorm_traj = get_xy_from_nt_seq(
                np.expand_dims(map_features, axis=0),
                [map_feature_helpers["ORACLE_CENTERLINE"]])
            norm = np.linalg.norm(unnorm_traj - agent_traj)
            # Re-reference column 1 so that it is 0 at the last observed step.
            ref_t = map_features[self.train_seq_size - 1, 1]
            map_features[:, 1] = map_features[:, 1] - ref_t

            # Keys are inserted in the same order as before.
            return_dict = {
                'seq_path': seq_path,
                'train_traj': map_features[:self.train_seq_size, :],
            }
            if self.mode == "train":
                return_dict['gt_traj'] = map_features[self.train_seq_size:, :]
            else:
                return_dict['gt_unnorm_traj'] = agent_traj[self.train_seq_size:, :]
            return_dict.update(
                helpers=map_feature_helpers,
                norm=norm,
                ref_t=ref_t,
                social_features=social_features,
            )

            if self.save:
                with open(cache_path, 'wb') as f:
                    pickle.dump(return_dict, f)
            return return_dict

        if self.mode == "validate_multiple":
            seq_path = self.seq_paths[index]
            current_loader = self.afl.get(seq_path)
            agent_traj = current_loader.agent_traj
            _, _, map_feature_helpers = compute_features(
                seq_path, self.map_features_utils_instance,
                self.social_features_utils_instance, self.avm, 'test')
            return {
                'seq_path': seq_path,
                'helpers': map_feature_helpers,
                'train_unnorm_traj': agent_traj[0:self.train_seq_size, :],
                'gt_unnorm_traj': agent_traj[self.train_seq_size:, :],
                'city': current_loader.city,
                'norm': 0.0
            }

        if self.mode == "test":
            seq_path = self.seq_paths[index]
            _, _, map_feature_helpers = compute_features(
                seq_path, self.map_features_utils_instance,
                self.social_features_utils_instance, self.avm, 'test')
            return {
                'seq_path': seq_path,
                'helpers': map_feature_helpers
            }

        raise ValueError(
            f"Unsupported mode {self.mode!r}; expected one of 'train', "
            "'validate', 'validate_multiple', 'test'")
Exemplo n.º 4
0
                              collate_fn=collate_traj_multilane)
dataloader_val = DataLoader(dataset_val,
                            batch_size=64,
                            shuffle=False,
                            num_workers=8,
                            collate_fn=collate_traj_multilane)

# Keep only training sequences whose ``gt_traj`` maps back to
# ``gt_unnorm_traj`` closely.  The mapping uses get_xy_from_nt_seq along the
# oracle centerline.  The error is the L2 norm over each sequence's whole
# future trajectory.  The selected paths are pickled to train.pkl.
MAX_RECONSTRUCTION_ERROR = 0.05

all_correct_seq_path = []
total = 0
for traj_dict in dataloader_train:
    gt_traj = traj_dict['gt_traj']
    gt_unnorm_traj = traj_dict['gt_unnorm_traj'].numpy()
    all_centerlines = [
        helper["ORACLE_CENTERLINE"] for helper in traj_dict['helpers']
    ]
    pred_unnorm_traj = get_xy_from_nt_seq(gt_traj.numpy(), all_centerlines)
    # One error value per sequence in the batch (norm over the last two axes).
    reconstruction_error = np.linalg.norm(pred_unnorm_traj - gt_unnorm_traj,
                                          axis=(1, 2))
    is_exact = reconstruction_error < MAX_RECONSTRUCTION_ERROR
    seq_paths = traj_dict['seq_path']

    all_correct_seq_path.extend(
        path for path, keep in zip(seq_paths, is_exact) if keep)
    total += len(seq_paths)
    print(f"{len(all_correct_seq_path)}/{total} selected for train", end="\r")
print()
with open("train.pkl", 'wb') as f:
    pickle.dump(all_correct_seq_path, f)

# for i,traj_dict in enumerate(dataloader_val):
#     gt_traj=traj_dict['gt_traj']
#     gt_unnorm_traj=traj_dict['gt_unnorm_traj'].numpy()