Beispiel #1
0
def compute_best_candidates(agent_track: np.ndarray, obs_len: int,
                            pred_len: int, raw_data_format: Dict[str, int]):
    """Select the best straight and best turning candidate centerlines.

    Queries the Argoverse map for candidate centerlines around the observed
    part of the agent trajectory, splits them into straight vs. turning
    (ignoring candidates with 10 or fewer points), and returns the oracle
    centerline of each group.

    Args:
        agent_track: Raw agent track rows, indexed via raw_data_format.
        obs_len: Number of observed timesteps to use for the query.
        pred_len: Unused here; kept for signature parity with siblings.
        raw_data_format: Column-name -> column-index mapping.

    Returns:
        Tuple (best_str, best_turn). If neither group has a usable
        candidate, two randomly chosen candidates are returned instead.
        If exactly one group is empty, its slot reuses the other's best.
    """
    # Observed portion of the trajectory in XY coordinates.
    observed = agent_track[:obs_len]
    xy = observed[:, [raw_data_format["X"],
                      raw_data_format["Y"]]].astype("float")

    avm = ArgoverseMap()
    city_name = agent_track[0, raw_data_format["CITY_NAME"]]

    candidate_centerlines = avm.get_candidate_centerlines_for_traj(
        xy,
        city_name,
        viz=False,
        max_search_radius=_MAX_SEARCH_RADIUS_CENTERLINES,
    )

    # Partition sufficiently long candidates into straight vs. turning.
    straight_candidates = []
    turning_candidates = []
    for candidate in candidate_centerlines:
        if candidate.shape[0] <= 10:
            continue
        if is_straight(candidate) == 1:
            straight_candidates.append(candidate)
        else:
            turning_candidates.append(candidate)

    if not straight_candidates and not turning_candidates:
        # Nothing was long enough: fall back to two random candidates.
        picks = np.random.randint(len(candidate_centerlines), size=2)
        return candidate_centerlines[picks[0]], candidate_centerlines[picks[1]]

    if straight_candidates:
        best_str = get_oracle_from_candidate_centerlines(
            straight_candidates, xy)
    if turning_candidates:
        best_turn = get_oracle_from_candidate_centerlines(
            turning_candidates, xy)

    # If one group is empty, mirror the other group's best into its slot.
    if not straight_candidates:
        best_str = best_turn
    if not turning_candidates:
        best_turn = best_str

    return best_str, best_turn
def compute_map_features(agent_track: np.ndarray, obs_len: int, pred_len: int,
                         raw_data_format: Dict[str, int], mode: str):
    """Compute oracle-centerline map features for an agent track.

    Args:
        agent_track: Raw agent track rows, indexed via raw_data_format.
        obs_len: Number of observed timesteps (used in test mode).
        pred_len: Number of future timesteps (padded in test mode).
        raw_data_format: Column-name -> column-index mapping.
        mode: "test" to use only the observed prefix; otherwise the
            full trajectory is used.

    Returns:
        Tuple (map_features, oracle_centerline, delta_ref) where
        map_features is (T, 3): normalized (n, t) coordinates plus a
        per-step heading angle, oracle_centerline is the best-matching
        candidate centerline, and delta_ref is the first-frame (n, t)
        offset that was subtracted during normalization.
    """
    # Full trajectory in XY, plus the observed-only prefix for test mode.
    agent_xy = agent_track[:, [raw_data_format["X"], raw_data_format["Y"]
                               ]].astype("float")
    agent_obs = agent_track[:obs_len]
    agent_xy_obs = agent_obs[:, [raw_data_format["X"], raw_data_format["Y"]
                                 ]].astype("float")

    # In test mode the future is unknown, so only observed points are used.
    if mode == "test":
        xy = agent_xy_obs
    else:
        xy = agent_xy

    avm = ArgoverseMap()

    city_name = agent_track[0, raw_data_format["CITY_NAME"]]

    candidate_centerlines = avm.get_candidate_centerlines_for_traj(
        xy,
        city_name,
        viz=False,
        max_search_radius=_MAX_SEARCH_RADIUS_CENTERLINES,
    )
    # Oracle = the candidate centerline best matching the trajectory.
    oracle_centerline = get_oracle_from_candidate_centerlines(
        candidate_centerlines, xy)

    # Trajectory expressed as distances relative to the oracle centerline.
    oracle_nt_dist = get_nt_distance(xy, oracle_centerline, viz=False)

    # Features shifted so the first observation sits at the origin.
    oracle_nt_dist_norm = oracle_nt_dist - oracle_nt_dist[0, :]

    # Keep the absolute first-frame offset so callers can undo the shift.
    delta_ref = copy.deepcopy(oracle_nt_dist[0, :])
    # Convert absolute coordinates to per-step deltas, in place; iterate
    # backwards so each step reads its not-yet-modified predecessor.
    for i in range(xy.shape[0] - 1, 0, -1):
        oracle_nt_dist[i, :] = oracle_nt_dist[i, :] - oracle_nt_dist[i - 1, :]
    oracle_nt_dist[0, :] = 0

    # Angle of each step delta (arctan2 of the two delta components);
    # t=0 has no delta, so it copies the t=1 angle.
    angle_w_cl = np.zeros((xy.shape[0], 1))
    angle_w_cl[1:, 0] = np.arctan2(oracle_nt_dist[1:, 1], oracle_nt_dist[1:,
                                                                         0])
    angle_w_cl[0, :] = angle_w_cl[1, :]
    #    angle_w_cl[np.isnan(angle_w_cl)] = np.pi/2

    map_features = np.concatenate((oracle_nt_dist_norm, angle_w_cl), axis=1)

    # For test sequences, pad the unknown future with None placeholders.
    if mode == "test":
        map_features = np.concatenate(
            (map_features, np.full([pred_len, 3], None)), axis=0)

    return map_features, oracle_centerline, delta_ref
Beispiel #3
0
    def save_top_errors_accuracy(self, model_dir, model_path):
        """Visualize the highest- and lowest-error validation predictions.

        Loads the checkpointed model, runs the validation loader, records
        the per-batch best (min loss) and worst (max loss) samples, then
        saves sequence plots and prediction visualizations for the overall
        ``num_images`` worst samples under model_dir/visualization/high_errors/
        and the ``num_images`` best under .../low_errors/.

        Args:
            model_dir: Directory under which visualizations are written.
            model_path: Directory containing 'trellis-model.pt'.
        """
        subseq_len = args.subseq_len
        val_loss = 0
        mems = []
        num_batches = len(self.val_loader.batch_sampler)

        min_loss = np.inf
        max_loss = 0
        num_images = 10  # how many best/worst samples to visualize

        # Per-batch extrema accumulators: *_max hold the batch-worst sample,
        # *_min the batch-best sample.
        loss_list_max = []
        input_max_list = []
        pred_max_list = []
        target_max_list = []
        city_name_max = []
        seq_path_list_max = []

        loss_list_min = []
        input_min_list = []
        pred_min_list = []
        target_min_list = []
        city_name_min = []
        seq_path_list_min = []

        # Load checkpoint on CPU so this works on GPU-less machines too.
        checkpoint = torch.load(model_path + 'trellis-model.pt',
                                map_location=torch.device('cpu'))
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        for i_batch, traj_dict in enumerate(self.val_loader):
            print(f"Running {i_batch}/{num_batches}", end="\r")

            # Detach recurrent memory so the graph does not grow across batches.
            if mems:
                mems[0] = mems[0].detach()

            gt_traj = traj_dict['gt_unnorm_agent']
            target = traj_dict['gt_agent']
            train_traj = traj_dict['train_agent']

            if self.use_cuda:
                train_traj = train_traj.cuda()
                target = target.cuda()

            input_ = self.val_loader.dataset.inverse_transform(
                train_traj, traj_dict)
            (_, _, output), mems = self.model(train_traj,
                                              target,
                                              mems,
                                              train_step=self.train_step,
                                              f_thres=args.f_thres,
                                              b_thres=args.b_thres,
                                              subseq_len=subseq_len,
                                              decode=True)
            output = self.val_loader.dataset.inverse_transform(
                output, traj_dict)

            output = output.cpu()
            input_ = input_.cpu()
            target = target.cpu()

            # Per-sample L2 error between prediction and unnormalized GT.
            loss = torch.norm(output.reshape(output.shape[0], -1) -
                              gt_traj.reshape(gt_traj.shape[0], -1),
                              dim=1)
            min_loss, min_index = torch.min(loss, dim=0)
            max_loss, max_index = torch.max(loss, dim=0)

            input_min_list.append(input_[min_index])
            pred_min_list.append(output[min_index])
            target_min_list.append(gt_traj[min_index])

            input_max_list.append(input_[max_index])
            pred_max_list.append(output[max_index])
            target_max_list.append(gt_traj[max_index])

            city_name_min.append(traj_dict['city'][min_index])
            city_name_max.append(traj_dict['city'][max_index])

            seq_path_list_max.append(traj_dict['seq_path'][max_index])
            seq_path_list_min.append(traj_dict['seq_path'][min_index])

            # BUG FIX: these two appends were swapped (max list recorded the
            # batch minimum and vice versa), so the argsort selection below
            # ranked the wrong samples. The max list must track max_loss.
            loss_list_max.append(max_loss.data)
            loss_list_min.append(min_loss.data)

        # Indices of the num_images largest losses, descending.
        loss_list_max_array = np.array(loss_list_max)
        loss_list_max = list(loss_list_max_array.argsort()[-num_images:][::-1])

        # Indices of the num_images smallest losses, ascending.
        loss_list_min_array = np.array(loss_list_min)
        loss_list_min = list(loss_list_min_array.argsort()[:num_images])

        avm = ArgoverseMap()

        high_error_path = model_dir + "/visualization/high_errors/"
        low_error_path = model_dir + "/visualization/low_errors/"

        # exist_ok avoids the check-then-create race of os.path.exists.
        os.makedirs(high_error_path, exist_ok=True)
        os.makedirs(low_error_path, exist_ok=True)

        input_max = []
        pred_max = []
        target_max = []
        city_max = []

        centerlines_max = []
        for i, index in enumerate(loss_list_max):
            print(f"Max: {i}")
            input_max.append(input_max_list[index].detach().numpy())
            pred_max.append([pred_max_list[index].detach().numpy()])
            target_max.append(target_max_list[index].detach().numpy())
            city_max.append(city_name_max[index])
            viz_sequence(df=pd.read_csv(seq_path_list_max[index]),
                         save_path=f"{high_error_path}/dataframe_{i}.png",
                         show=True,
                         avm=avm)
            centerlines_max.append(
                avm.get_candidate_centerlines_for_traj(input_max[-1],
                                                       city_max[-1],
                                                       viz=False))
        print("Created max array")

        input_min = []
        pred_min = []
        target_min = []
        city_min = []
        centerlines_min = []
        for i, index in enumerate(loss_list_min):
            print(f"Min: {i}")
            input_min.append(input_min_list[index].detach().numpy())
            pred_min.append([pred_min_list[index].detach().numpy()])
            target_min.append(target_min_list[index].detach().numpy())
            city_min.append(city_name_min[index])
            viz_sequence(df=pd.read_csv(seq_path_list_min[index]),
                         save_path=f"{low_error_path}/dataframe_{i}.png",
                         show=True,
                         avm=avm)
            centerlines_min.append(
                avm.get_candidate_centerlines_for_traj(input_min[-1],
                                                       city_min[-1],
                                                       viz=False))
        print("Created min array")

        print(f"Saving max visualizations at {high_error_path}")
        viz_predictions(input_=np.array(input_max),
                        output=pred_max,
                        target=np.array(target_max),
                        centerlines=centerlines_max,
                        city_names=np.array(city_max),
                        avm=avm,
                        save_path=high_error_path)

        print(f"Saving min visualizations at {low_error_path}")
        viz_predictions(input_=np.array(input_min),
                        output=pred_min,
                        target=np.array(target_min),
                        centerlines=centerlines_min,
                        city_names=np.array(city_min),
                        avm=avm,
                        save_path=low_error_path)
Beispiel #4
0
class Argoverse_MultiLaneCentre_Data(Argoverse_Data):
    """Dataset yielding agent trajectories expressed in (normal, tangential)
    coordinates relative to every candidate centerline of each sequence.

    With oracle=True only the single best-matching centerline is used.
    """

    def __init__(self,
                 root_dir='argoverse-data//data',
                 avm=None,
                 social=False,
                 train_seq_size=20,
                 cuda=False,
                 test=False,
                 oracle=False):
        # BUG FIX: super() previously named Argoverse_LaneCentre_Data, a
        # different class, which raises at construction time; it must name
        # this class.
        super(Argoverse_MultiLaneCentre_Data,
              self).__init__(root_dir, train_seq_size, cuda, test)
        if avm is None:
            self.avm = ArgoverseMap()
        else:
            self.avm = avm
        self.stationary_threshold = 2.0
        # When True, collapse candidates to the single oracle centerline.
        self.oracle = oracle
        print("Done loading map")

    def __len__(self):
        # return 10000
        return len(self.seq_paths)

    def inverse_transform(self, trajectory, traj_dict):
        """Map a normalized (n, t) trajectory back to XY coordinates using
        the centerline(s) stored in traj_dict."""
        centerline = traj_dict['centerline']
        if self.use_cuda:
            trajectory = trajectory.cpu()
        out = get_xy_from_nt_seq(nt_seq=trajectory, centerlines=centerline)
        out = torch.Tensor(out).float()
        if self.use_cuda:
            out = out.cuda()
        return out

    def __getitem__(self, index):
        current_loader = self.afl.get(self.seq_paths[index])
        agent_traj = current_loader.agent_traj
        candidate_centerlines = self.avm.get_candidate_centerlines_for_traj(
            agent_traj, current_loader.city, viz=False)
        if self.oracle:
            candidate_centerlines = [
                get_oracle_from_candidate_centerlines(candidate_centerlines,
                                                      agent_traj)
            ]
        if self.mode_test:
            seq_index = int(
                os.path.basename(self.seq_paths[index]).split('.')[0])

            agent_train_traj = agent_traj[:self.train_seq_size, :]
            all_centerline_traj = []
            # BUG FIX: the loop body previously referenced an undefined
            # `current_centerline`; it must use the loop variable.
            for centerline in candidate_centerlines:
                all_centerline_traj.append(
                    torch.Tensor(
                        get_nt_distance(agent_train_traj,
                                        centerline)).float())

            return {
                'seq_index': seq_index,
                'train_agent': all_centerline_traj,
                'centerline': candidate_centerlines,
                'city': current_loader.city
            }

        else:
            agent_train_traj = agent_traj[:self.train_seq_size, :]
            agent_gt_traj = agent_traj[self.train_seq_size:, ]
            all_centerline_train_traj = []
            all_centerline_gt_traj = []
            # BUG FIX: same undefined `current_centerline` reference as above.
            for centerline in candidate_centerlines:
                all_centerline_train_traj.append(
                    torch.Tensor(
                        get_nt_distance(agent_train_traj,
                                        centerline)).float())
                all_centerline_gt_traj.append(
                    torch.Tensor(
                        get_nt_distance(agent_gt_traj,
                                        centerline)).float())

            agent_unnorm_gt_traj = torch.Tensor(
                agent_traj[self.train_seq_size:, ]).float()

            # BUG FIX: `current_centerline` was undefined here; return the
            # candidate list, matching the test-mode branch and what
            # inverse_transform expects under 'centerline'.
            return {
                'train_agent': all_centerline_train_traj,
                'gt_agent': all_centerline_gt_traj,
                'gt_unnorm_agent': agent_unnorm_gt_traj,
                'centerline': candidate_centerlines,
                'city': current_loader.city
            }
def test_get_candidate_centerlines_for_traj():
    """Test get_candidate_centerlines_for_traj()

    -180        .  .  .  .  .                   -100
                                                            2340
                            v
                            |                                 .
                            |                                 .
                            *  (CL1)                          .
                             \
                              \
                        (CL2)  \
                      >---------*-------------------->
                        s x x x x x x x x x e                 .
        >-------------------------------------------->        .
                        (CL3)                               2310
    """
    # Straight eastbound query trajectory in MIA city coordinates.
    xy = np.array([[-130.0, 2315.0], [-129.0, 2315.0], [-128.0, 2315.0],
                   [-127, 2315], [-126, 2315], [-125, 2315], [-124, 2315]])
    city_name = "MIA"
    avm = ArgoverseMap()
    candidate_centerlines = avm.get_candidate_centerlines_for_traj(
        xy, city_name)

    assert len(candidate_centerlines) == 3, "Number of candidates wrong!"

    # Hard-coded expected coordinates for the three candidates (CL1-CL3).
    expected_centerlines = [
        np.array([
            [-131.88540689, 2341.87225878],
            [-131.83054027, 2340.33723194],
            [-131.77567365, 2338.8022051],
            [-131.72080703, 2337.26717826],
            [-131.66594041, 2335.73215142],
            [-131.61107379, 2334.19712458],
            [-131.55620718, 2332.66209774],
            [-131.50134056, 2331.1270709],
            [-131.44647394, 2329.59204406],
            [-131.39160732, 2328.05701721],
            [-131.39160732, 2328.05701721],
            [-131.37997138, 2327.72338427],
            [-131.36833545, 2327.38975132],
            [-131.35669951, 2327.05611837],
            [-131.34506358, 2326.72248542],
            [-131.33342764, 2326.38885247],
            [-131.32179171, 2326.05521952],
            [-131.31015577, 2325.72158657],
            [-131.29851984, 2325.38795362],
            [-131.2868839, 2325.05432067],
            [-131.2868839, 2325.05432067],
            [-131.19279519, 2322.55119928],
            [-130.98376304, 2320.05690639],
            [-130.24692629, 2317.70490846],
            [-128.37426431, 2316.09358878],
            [-125.9878693, 2315.38876171],
            [-123.48883479, 2315.29784077],
            [-120.98715427, 2315.43423973],
            [-118.48467829, 2315.55478278],
            [-115.9822023, 2315.67532583],
            [-115.9822023, 2315.67532583],
            [-114.27604136, 2315.74436169],
            [-112.56988042, 2315.81339756],
            [-110.86371948, 2315.88243342],
            [-109.15755854, 2315.95146928],
            [-107.4513976, 2316.02050515],
            [-105.74523665, 2316.08954101],
            [-104.03907571, 2316.15857687],
            [-102.33291477, 2316.22761274],
            [-100.62675383, 2316.2966486],
        ]),
        np.array([
            [-139.13361714, 2314.54725812],
            [-136.56123771, 2314.67259898],
            [-133.98885829, 2314.79793983],
            [-131.41647886, 2314.92328069],
            [-128.84409943, 2315.04862155],
            [-126.27172001, 2315.1739624],
            [-123.69934058, 2315.29930326],
            [-121.12696116, 2315.42464412],
            [-118.55458173, 2315.54998497],
            [-115.9822023, 2315.67532583],
            [-115.9822023, 2315.67532583],
            [-114.27604136, 2315.74436169],
            [-112.56988042, 2315.81339756],
            [-110.86371948, 2315.88243342],
            [-109.15755854, 2315.95146928],
            [-107.4513976, 2316.02050515],
            [-105.74523665, 2316.08954101],
            [-104.03907571, 2316.15857687],
            [-102.33291477, 2316.22761274],
            [-100.62675383, 2316.2966486],
        ]),
        np.array([
            [-178.94773558, 2309.75038731],
            [-175.73132051, 2309.8800903],
            [-172.51490545, 2310.00979328],
            [-169.29849039, 2310.13949626],
            [-166.08207532, 2310.26919925],
            [-162.86566026, 2310.39890223],
            [-159.64924519, 2310.52860522],
            [-156.43283013, 2310.6583082],
            [-153.21641506, 2310.78801118],
            [-150.0, 2310.91771417],
            [-150.0, 2310.91771417],
            [-148.77816698, 2310.97013154],
            [-147.55633396, 2311.0225489],
            [-146.33450094, 2311.07496627],
            [-145.11266792, 2311.12738364],
            [-143.89083489, 2311.17980101],
            [-142.66900187, 2311.23221837],
            [-141.44716885, 2311.28463574],
            [-140.22533583, 2311.33705311],
            [-139.00350281, 2311.38947048],
            [-139.00350281, 2311.38947048],
            [-136.42679274, 2311.51113082],
            [-133.85008268, 2311.63279117],
            [-131.27337261, 2311.75445152],
            [-128.69666254, 2311.87611187],
            [-126.11995247, 2311.99777222],
            [-123.54324241, 2312.11943257],
            [-120.96653234, 2312.24109292],
            [-118.38982227, 2312.36275327],
            [-115.8131122, 2312.48441361],
            [-115.8131122, 2312.48441361],
            [-114.11040334, 2312.54102742],
            [-112.40815545, 2312.6106056],
            [-110.70605773, 2312.68440659],
            [-109.00396, 2312.75820759],
            [-107.30186227, 2312.83200858],
            [-105.59976454, 2312.90580958],
            [-103.89766681, 2312.97961057],
            [-102.19556909, 2313.05341156],
            [-100.49347136, 2313.12721256],
        ]),
    ]

    # Idiom fix: iterate pairs directly instead of range(len(...)) indexing.
    # Lengths are equal by the count assertion above.
    for expected, actual in zip(expected_centerlines, candidate_centerlines):
        assert np.allclose(expected, actual), "Centerline coordinates wrong!"
class Argoverse_MultiLane_Data(Argoverse_Data):
    """Dataset producing oracle-centerline map features for Argoverse
    sequences, with optional pickled caching of precomputed items.

    Modes: "train"/"validate" (oracle-centerline features, optionally
    loaded from / saved to pickle caches), "validate_multiple" and
    "test" (features computed with the external compute_features helper).
    """

    def __init__(self,
                 root_dir='argoverse-data//data',
                 avm=None,
                 train_seq_size=20,
                 mode="train",
                 save=False,
                 load_saved=False):
        super(Argoverse_MultiLane_Data, self).__init__(root_dir,
                                                       train_seq_size)
        # Reuse a caller-supplied map instance when available; constructing
        # ArgoverseMap is expensive.
        if avm is None:
            self.avm = ArgoverseMap()
        else:
            self.avm = avm
        # if mode=="train":
        #     with open('train.pkl', 'rb') as f:
        #         self.seq_paths=pickle.load(f)
        # elif mode=="validate":
        #     with open('val.pkl', 'rb') as f:
        #         self.seq_paths=pickle.load(f)
        self.map_features_utils_instance = MapFeaturesUtils()
        self.social_features_utils_instance = SocialFeaturesUtils()
        self.mode = mode
        # save: write each computed item to a pickle cache.
        # load_saved: read items from the pickle cache instead of computing.
        self.save = save
        self.load_saved = load_saved

    def compute_features_old(self,
                             seq_path,
                             map_instance,
                             social_feature_instance,
                             avm,
                             mode="train"):
        """Legacy feature computation kept for reference.

        NOTE(review): check1 is hard-coded True, so the else branch is dead
        code — and it references `agent_track`, which is undefined in this
        scope, so it would raise NameError if ever reached.
        NOTE(review): the "validate_multiple" branch computes candidates but
        falls through, implicitly returning None.
        """
        check1 = True
        if check1:
            if mode == "train" or mode == "validate":
                current_loader = self.afl.get(seq_path)
                agent_traj = current_loader.agent_traj
                # df = pd.read_csv(seq_path, dtype={"TIMESTAMP": str})
                # agent_track = df[df["OBJECT_TYPE"] == "AGENT"].values
                candidate_centerlines = self.avm.get_candidate_centerlines_for_traj(
                    agent_traj, current_loader.city, viz=False)
                current_centerline = get_oracle_from_candidate_centerlines(
                    candidate_centerlines, agent_traj)
                agent_traj_norm = get_nt_distance(agent_traj,
                                                  current_centerline)
                return None, agent_traj_norm, {
                    "ORACLE_CENTERLINE": current_centerline
                }
            elif mode == "validate_multiple":
                current_loader = self.afl.get(seq_path)
                agent_traj = current_loader.agent_traj
                candidate_centerlines = self.avm.get_candidate_centerlines_for_traj(
                    agent_traj, current_loader.city, viz=False)
        else:
            map_features, map_feature_helpers = self.map_features_utils_instance.compute_map_features(
                agent_track, 20, 50, RAW_DATA_FORMAT, mode, avm)
            return None, map_features, map_feature_helpers

    def __getitem__(self, index):
        """Return a feature dict for the sequence at `index`; the keys
        depend on self.mode (see class docstring)."""
        if self.mode == "train" or self.mode == "validate":
            # NOTE(review): cache paths are hard-coded to a specific user's
            # scratch directory — consider making them configurable.
            if self.load_saved and self.mode == "train":
                with open(
                        f"/home/scratch/nitinsin/argoverse/train/{index}.pkl",
                        'rb') as f:
                    train_dict = pickle.load(f)
                return train_dict
            if self.load_saved and self.mode == "validate":
                with open(f"/home/scratch/nitinsin/argoverse/val/{index}.pkl",
                          'rb') as f:
                    val_dict = pickle.load(f)
                return val_dict
            current_loader = self.afl.get(self.seq_paths[index])
            agent_traj = current_loader.agent_traj
            social_features, map_features, map_feature_helpers = compute_features(
                self.seq_paths[index], self.map_features_utils_instance,
                self.social_features_utils_instance, self.avm, 'train')
            # social_features,map_features,map_feature_helpers = self.compute_features_old(
            #     self.seq_paths[index], None,None,None,'train')
            # Round-trip the nt features back to XY to measure the
            # reconstruction error against the raw trajectory.
            unnorm_traj = get_xy_from_nt_seq(
                np.expand_dims(map_features, axis=0),
                [map_feature_helpers["ORACLE_CENTERLINE"]])
            norm = np.linalg.norm(unnorm_traj - agent_traj)
            # if norm>1.0:
            #     print(f"Norm at index {index}",norm)
            # Shift tangential coordinates so the last observed step is 0.
            ref_t = map_features[self.train_seq_size - 1, 1]
            map_features[:, 1] = map_features[:, 1] - ref_t
            if self.mode == "train":
                return_dict = {
                    'seq_path': self.seq_paths[index],
                    'train_traj': map_features[:self.train_seq_size, :],
                    'gt_traj': map_features[self.train_seq_size:, :],
                    'helpers': map_feature_helpers,
                    'norm': norm,
                    'ref_t': ref_t,
                    'social_features': social_features
                }
                if self.save:
                    with open(
                            f"/home/scratch/nitinsin/argoverse/train/{index}.pkl",
                            'wb') as f:
                        pickle.dump(return_dict, f)
            else:
                # Validation items carry the unnormalized GT future instead
                # of normalized gt_traj.
                return_dict = {
                    'seq_path': self.seq_paths[index],
                    'train_traj': map_features[:self.train_seq_size, :],
                    'gt_unnorm_traj': agent_traj[self.train_seq_size:, :],
                    'helpers': map_feature_helpers,
                    'norm': norm,
                    'ref_t': ref_t,
                    'social_features': social_features
                }
                if self.save:
                    with open(
                            f"/home/scratch/nitinsin/argoverse/val/{index}.pkl",
                            'wb') as f:
                        pickle.dump(return_dict, f)

            return return_dict
            # return {'seq_path':self.seq_paths[index],'train_unnorm_traj': agent_traj[:self.train_seq_size,:],
            #         'train_traj':map_features[:self.train_seq_size,:],'gt_traj':map_features[self.train_seq_size:,:],
            #         'gt_unnorm_traj':agent_traj[self.train_seq_size:,:],'helpers':map_feature_helpers,
            #         'norm_traj':map_features,'unnorm_traj':agent_traj}
        elif self.mode == "validate_multiple":
            current_loader = self.afl.get(self.seq_paths[index])
            agent_traj = current_loader.agent_traj
            social_features, map_features, map_feature_helpers = compute_features(
                self.seq_paths[index], self.map_features_utils_instance,
                self.social_features_utils_instance, self.avm, 'test')
            return {
                'seq_path': self.seq_paths[index],
                'helpers': map_feature_helpers,
                'train_unnorm_traj': agent_traj[0:self.train_seq_size, :],
                'gt_unnorm_traj': agent_traj[self.train_seq_size:, :],
                'city': current_loader.city,
                'norm': 0.0
            }
        elif self.mode == "test":
            social_features, map_features, map_feature_helpers = compute_features(
                self.seq_paths[index], self.map_features_utils_instance,
                self.social_features_utils_instance, self.avm, 'test')
            return {
                'seq_path': self.seq_paths[index],
                'helpers': map_feature_helpers
            }
#                            [0.9, 96],
#                            [0.5, 97],
#                            [0.6, 0.2],
#                            [0.8, 0.1],
#                            [99, 0.6]]))
# print(m)
# print(n)
# criteria = torch.nn.MSELoss()
# l = criteria(n, m)
# print(l)
#
# print(n[torch.where(torch.isnan(m))])
# p = torch.where(torch.isnan(m), n, m)
# print(p)
# q = torch.where(torch.isnan(m), torch.full_like(m, 0), m)
# print(q)
#
#
# l = criteria(n, p)
# print(l)
# l = criteria(n, q)
# print(l)

## ================ Ad-hoc test of the get_candidate_centerlines_for_traj method ===========
am = ArgoverseMap()  # map query object
traj = np.array([[3181,1671],[3178,1675],[3176,1680],[3170,1682]])  # sample trajectory in city coordinates
altraj = am.get_candidate_centerlines_for_traj(traj,'PIT')  # candidate centerlines around the trajectory (Pittsburgh map)



Beispiel #8
0
    target_array = []
    city_names = []
    centerlines = []

    # Load the four per-model prediction dicts for the same sequence.
    with open(path1, 'rb') as f:
        dict1 = pickle.load(f)
    with open(path2, 'rb') as f:
        dict2 = pickle.load(f)
    with open(path3, 'rb') as f:
        dict3 = pickle.load(f)
    with open(path4, 'rb') as f:
        dict4 = pickle.load(f)

    seq_index1 = dict1['seq_index']
    seq_index2 = dict2['seq_index']
    # BUG FIX: seq_index3/seq_index4 previously read dict2 (copy-paste),
    # so mismatched dict3/dict4 sequences were never detected.
    seq_index3 = dict3['seq_index']
    seq_index4 = dict4['seq_index']
    # BUG FIX: the old check and-ed a duplicated `seq_index1!=seq_index2`
    # test, which could never flag a single mismatched pair. All four
    # dicts must describe the same sequence.
    if not (seq_index1 == seq_index2 == seq_index3 == seq_index4):
        print("Something is wrong")
        exit()

    input_array.extend(
        [dict1['input'], dict2['input'], dict3['input'], dict4['input']])
    target_array.extend(
        [dict1['target'], dict2['target'], dict3['target'], dict4['target']])
    # Each model's output is wrapped in its own single-element list because
    # viz_predictions expects a list of candidate outputs per sample.
    pred_array.extend([[dict1['output']], [dict2['output']],
                       [dict3['output']], [dict4['output']]])
    city_names.extend(
        [dict1['city'], dict2['city'], dict3['city'], dict4['city']])

    # Candidate centerlines around each model's input trajectory.
    for d in (dict1, dict2, dict3, dict4):
        centerlines.append(
            avm.get_candidate_centerlines_for_traj(d['input'],
                                                   d['city'],
                                                   viz=False))

    viz_predictions(input_=np.array(input_array),
                    output=pred_array,
                    target=np.array(target_array),
                    centerlines=centerlines,
                    city_names=np.array(city_names),
                    avm=avm,
                    save_path=f"{output_path}/{seq_index1}.png")
Beispiel #9
0
    def save_top_errors_accuracy_single_pred(self):
        afl=ArgoverseForecastingLoader("data/val/data/")
        self.model.load_state_dict(torch.load(self.model_dir+'best-model.pt')['model_state_dict'])
        self.model.eval()
        min_loss=np.inf
        max_loss=0
        num_images=10
        
        loss_list_max=[]
        input_max_list=[]
        pred_max_list=[]
        target_max_list=[]
        city_name_max=[]
        seq_path_list_max=[]
        
        loss_list_min=[]
        input_min_list=[]
        pred_min_list=[]
        target_min_list=[]
        city_name_min=[]
        seq_path_list_min=[]

        num_batches=len(self.val_loader.batch_sampler)
        for i_batch,traj_dict in enumerate(self.val_loader):
            print(f"Running {i_batch}/{num_batches}",end="\r")
            # pdb.set_trace()
            # pred_traj=self.model(traj_dict)
            # pred_traj=self.val_loader.dataset.inverse_transform(pred_traj,traj_dict)
            gt_traj=traj_dict['gt_unnorm_traj']
            output=self.model(traj_dict,mode='validate')
            output=output.cpu()
            loss=torch.norm(output.reshape(output.shape[0],-1)-gt_traj.reshape(gt_traj.shape[0],-1),dim=1)
            min_loss,min_index=torch.min(loss,dim=0)
            max_loss,max_index=torch.max(loss,dim=0)
            print(f"Min loss: {min_loss}, Max loss: {max_loss}" )

            
            seq_path_list_max.append(traj_dict['seq_path'][max_index])
            seq_path_list_min.append(traj_dict['seq_path'][min_index])

            loader_max=afl.get(traj_dict['seq_path'][max_index])
            input_max_list.append(loader_max.agent_traj[0:20,:])
            city_name_max.append(loader_max.city)
            del loader_max
            
            loader_min=afl.get(traj_dict['seq_path'][min_index])
            input_min_list.append(loader_min.agent_traj[0:20,:])
            city_name_min.append(loader_min.city)
            del loader_min
            
            
            pred_min_list.append(output[min_index])
            target_min_list.append(gt_traj[min_index])

            
            pred_max_list.append(output[max_index])
            target_max_list.append(gt_traj[max_index])
            
            
            

            # loss_list_max.append(min_loss.data)
            # loss_list_min.append(max_loss.data)

            loss_list_max.append(max_loss.data)
            loss_list_min.append(min_loss.data)
            # torch.cuda.empty_cache()
        # pdb.set_trace()
        loss_list_max_array=np.array(loss_list_max)
        loss_list_max=list(loss_list_max_array.argsort()[-num_images:][::-1])

        loss_list_min_array=np.array(loss_list_min)
        loss_list_min=list(loss_list_min_array.argsort()[:num_images])

        # pdb.set_trace()

        avm=ArgoverseMap()
        
        high_error_path=model_dir+"/visualization/high_errors/"
        low_error_path=model_dir+"/visualization/low_errors/"

        if not os.path.exists(high_error_path):
            os.makedirs(high_error_path)

        if not os.path.exists(low_error_path):
            os.makedirs(low_error_path)

        # if self.use_cuda:
        #     input_=input_.cpu()
        #     output=output.cpu()

        input_max=[]
        pred_max=[]
        target_max=[]
        city_max=[]
        # import pdb;pdb.set_trace()
        # seq_path_max=[]
        centerlines_max=[]
        for i,index in enumerate(loss_list_max):
            print(f"Max: {i}")
            # pdb.set_trace()
            input_max.append(input_max_list[index])
            pred_max.append([pred_max_list[index].detach().numpy()])
            print(f"Difference in predicted and input_traj for maximum at {i} is {np.linalg.norm(input_max[-1]-pred_max[-1][0][0:20,:])}")
            target_max.append(target_max_list[index].detach().numpy())
            city_max.append(city_name_max[index])
            viz_sequence(df=pd.read_csv(seq_path_list_max[index]) ,save_path=f"{high_error_path}/dataframe_{i}.png",show=True,avm=avm)
            centerlines_max.append(avm.get_candidate_centerlines_for_traj(input_max[-1], city_max[-1],viz=False))
        print("Created max array")
        input_min=[]
        pred_min=[]
        target_min=[]
        city_min=[]
        # seq_path_min=[]
        centerlines_min=[]
        for i,index in enumerate(loss_list_min):
            # pdb.set_trace()
            print(f"Min: {i}")
            input_min.append(input_min_list[index])
            pred_min.append([pred_min_list[index].detach().numpy()])
            print(f"Difference in predicted and input_traj for minimum at {i} is {np.linalg.norm(input_min[-1]-pred_min[-1][0][0:20,:])}")
            target_min.append(target_min_list[index].detach().numpy())
            city_min.append(city_name_min[index])
            # seq_path_min.append(seq_path_list_min[index])
            viz_sequence(df=pd.read_csv(seq_path_list_min[index]) ,save_path=f"{low_error_path}/dataframe_{i}.png",show=True,avm=avm)
            centerlines_min.append(avm.get_candidate_centerlines_for_traj(input_min[-1], city_min[-1],viz=False))
        print("Created min array")
        print(f"Saving max visualizations at {high_error_path}")
        # import pdb;pdb.set_trace()
        viz_predictions(input_=np.array(input_max), output=pred_max,target=np.array(target_max),centerlines=centerlines_max,city_names=np.array(city_max),avm=avm,save_path=high_error_path)
        
        print(f"Saving min visualizations at {low_error_path}")
        # viz_predictions(input_=pred_min[0], output=pred_min,target=np.array(target_min),centerlines=centerlines_min,city_names=np.array(city_min),avm=avm,save_path=low_error_path)
        # viz_predictions(input_=np.expand_dims(input_min[0],axis=0),output=pred_min[0],target=np.expand_dims(target_min[0],axis=0),centerlines=[centerlines_min], city_names=np.expand_dims(city_min[0],axis=0),avm=avm,save_path=low_error_path)
        viz_predictions(input_=np.array(input_min), output=pred_min,target=np.array(target_min),centerlines=centerlines_min,city_names=np.array(city_min),avm=avm,save_path=low_error_path)
Beispiel #10
0
    def save_top_accuracy(self):
        """Visualize the highest- and lowest-error validation sequences.

        Loads the best checkpoint, runs the model in multi-modal validation
        mode over ``self.multi_val_loader``, scores each batch by its
        best-of-K L2 displacement error, then saves trajectory / map
        visualizations for the ``num_images`` worst cases under
        ``<model_dir>/visualization/high_errors/`` and the ``num_images``
        best cases under ``<model_dir>/visualization/low_errors/``.

        NOTE(review): assumes each element of ``pred_traj`` is a stack of K
        candidate trajectories indexed as ``pred_traj[i][j]`` — confirm
        against the model's ``validate_multiple`` output.
        """
        print("running save accuracy")
        self.model.load_state_dict(
            torch.load(self.model_dir + 'best-model.pt')['model_state_dict'])
        self.model.eval()

        num_images = 10

        # Per-batch records for the extreme (worst / best loss) samples.
        loss_list_max = []
        input_max_list = []
        pred_max_list = []
        target_max_list = []
        city_name_max = []
        seq_path_list_max = []

        loss_list_min = []
        input_min_list = []
        pred_min_list = []
        target_min_list = []
        city_name_min = []
        seq_path_list_min = []

        num_batches = len(self.multi_val_loader.batch_sampler)
        for i_batch, traj_dict in enumerate(self.multi_val_loader):
            print(f"Running {i_batch}/{num_batches}", end="\r")
            gt_traj = traj_dict['gt_unnorm_traj'].numpy()
            pred_traj = self.model(traj_dict, mode='validate_multiple')

            # Best-of-K loss per sample: smallest L2 error over the K
            # predicted modes.
            loss = []
            for index in range(len(pred_traj)):
                loss_temp = [
                    np.linalg.norm(pred_traj[index][j] - gt_traj[index])
                    for j in range(pred_traj[index].shape[0])
                ]
                loss.append(min(loss_temp))
            loss = torch.Tensor(loss).float()
            min_loss, min_index = torch.min(loss, dim=0)
            max_loss, max_index = torch.max(loss, dim=0)

            input_min_list.append(traj_dict['train_unnorm_traj'][min_index])
            pred_min_list.append(pred_traj[min_index])
            target_min_list.append(traj_dict['gt_unnorm_traj'][min_index])

            input_max_list.append(traj_dict['train_unnorm_traj'][max_index])
            pred_max_list.append(pred_traj[max_index])
            target_max_list.append(traj_dict['gt_unnorm_traj'][max_index])

            city_name_min.append(traj_dict['city'][min_index])
            city_name_max.append(traj_dict['city'][max_index])

            seq_path_list_max.append(traj_dict['seq_path'][max_index])
            seq_path_list_min.append(traj_dict['seq_path'][min_index])

            # FIX: these two appends were swapped (max list recorded the
            # min loss and vice versa), so the "high_errors" folder showed
            # the best cases.  The corrected version matches the parallel
            # code earlier in this file.
            loss_list_max.append(max_loss.data)
            loss_list_min.append(min_loss.data)

        # Indices of the num_images worst batches, descending by loss ...
        loss_list_max_array = np.array(loss_list_max)
        loss_list_max = list(loss_list_max_array.argsort()[-num_images:][::-1])

        # ... and the num_images best batches, ascending by loss.
        loss_list_min_array = np.array(loss_list_min)
        loss_list_min = list(loss_list_min_array.argsort()[:num_images])

        avm = ArgoverseMap()

        high_error_path = self.model_dir + "/visualization/high_errors/"
        low_error_path = self.model_dir + "/visualization/low_errors/"
        os.makedirs(high_error_path, exist_ok=True)
        os.makedirs(low_error_path, exist_ok=True)

        # Gather the worst cases and render their raw-sequence plots.
        input_max = []
        pred_max = []
        target_max = []
        city_max = []
        centerlines_max = []
        for i, index in enumerate(loss_list_max):
            print(f"Max: {i}")
            input_max.append(input_max_list[index].numpy())
            pred_max.append(pred_max_list[index])
            target_max.append(target_max_list[index].numpy())
            city_max.append(city_name_max[index])
            viz_sequence(df=pd.read_csv(seq_path_list_max[index]),
                         save_path=f"{high_error_path}/dataframe_{i}.png",
                         show=True, avm=avm)
            centerlines_max.append(
                avm.get_candidate_centerlines_for_traj(
                    input_max[-1], city_max[-1], viz=False))
        print("Created max array")

        # Gather the best cases the same way.
        input_min = []
        pred_min = []
        target_min = []
        city_min = []
        centerlines_min = []
        for i, index in enumerate(loss_list_min):
            print(f"Min: {i}")
            input_min.append(input_min_list[index].numpy())
            pred_min.append(pred_min_list[index])
            target_min.append(target_min_list[index].numpy())
            city_min.append(city_name_min[index])
            viz_sequence(df=pd.read_csv(seq_path_list_min[index]),
                         save_path=f"{low_error_path}/dataframe_{i}.png",
                         show=True, avm=avm)
            centerlines_min.append(
                avm.get_candidate_centerlines_for_traj(
                    input_min[-1], city_min[-1], viz=False))
        # FIX: removed a leftover `import pdb; pdb.set_trace()` that halted
        # every run at this point.
        print("Created min array")
        print(f"Saving max visualizations at {high_error_path}")
        viz_predictions(input_=np.array(input_max), output=pred_max,
                        target=np.array(target_max),
                        centerlines=centerlines_max,
                        city_names=np.array(city_max),
                        avm=avm, save_path=high_error_path)
        print(f"Saving min visualizations at {low_error_path}")
        viz_predictions(input_=np.array(input_min), output=pred_min,
                        target=np.array(target_min),
                        centerlines=centerlines_min,
                        city_names=np.array(city_min),
                        avm=avm, save_path=low_error_path)