Beispiel #1
0
def get_multi_object_dataset():
    """Build the multi-object training and validation datasets.

    The dataset family is chosen by ``Settings().dataset``:

    * ``'NGSIM'``     -> ``TrainSet_*_v2.mat`` / ``ValSet_*_v2.mat`` files in
      ``args.NGSIM_data_directory``.
    * ``'Argoverse'`` -> the ``train/data`` and ``val/data`` folders of
      ``args.argoverse_data_directory``.
    * ``'Fusion'``    -> ``train_sequenced_data.tar`` / ``val_sequenced_data.tar``
      in ``args.fusion_data_directory``.

    Returns:
        tuple: ``(trSet, valSet)``, the training and validation datasets.

    Raises:
        RuntimeError: if ``args.dataset`` is none of the names above.
    """
    args = Settings()
    if args.dataset == 'NGSIM':
        trSet = MultiObjectNGSIMDataset(
            args.NGSIM_data_directory + 'TrainSet_traj_v2.mat',
            args.NGSIM_data_directory + 'TrainSet_tracks_v2.mat',
            args=args)
        valSet = MultiObjectNGSIMDataset(
            args.NGSIM_data_directory + 'ValSet_traj_v2.mat',
            args.NGSIM_data_directory + 'ValSet_tracks_v2.mat',
            args=args)
    elif args.dataset == 'Argoverse':
        trSet = MultiObjectArgoverseDataset(args.argoverse_data_directory +
                                            'train/data',
                                            args=args)
        valSet = MultiObjectArgoverseDataset(args.argoverse_data_directory +
                                             'val/data',
                                             args=args)
    elif args.dataset == 'Fusion':
        trSet = MultiObjectFusionDataset(args.fusion_data_directory +
                                         'train_sequenced_data.tar',
                                         args=args)
        valSet = MultiObjectFusionDataset(args.fusion_data_directory +
                                          'val_sequenced_data.tar',
                                          args=args)
    else:
        # Without this branch an unknown name fell through and failed on the
        # return statement with an opaque UnboundLocalError.
        raise RuntimeError(
            'Multi object loader does not support dataset {!r}; expected '
            'NGSIM, Argoverse or Fusion.'.format(args.dataset))

    return trSet, valSet
Beispiel #2
0
    def __init__(self):
        """Read the field geometry from ``Settings`` and derive image size/center.

        The image size in pixels is the field size scaled by
        ``pixel_per_meters`` and truncated to ``int``; the center is the
        integer midpoint ``[width // 2, height // 2]``.
        """
        settings = Settings()
        self.args = settings
        self.field_height = settings.field_height
        self.field_width = settings.field_width
        self.pixel_per_meters = settings.pixel_per_meters

        scale = self.pixel_per_meters
        self.image_height = int(scale * self.field_height)
        self.image_width = int(scale * self.field_width)

        half_width = self.image_width // 2
        half_height = self.image_height // 2
        self.image_center = np.array([half_width, half_height])
Beispiel #3
0
    def __init__(self, head_selection):
        """Create the empty plot data sources and glyph templates for a scene.

        Args:
            head_selection: stored as ``self.head_selection``; presumably a
                callback used when attention heads are selected (TODO confirm
                against callers).

        Sets up one ``ColumnDataSource`` for objects and one for vehicles
        (with a tap/selection callback), Circle/Rect/Ellipse/Line/Scatter glyph
        templates, empty per-item source lists with their counters, and the
        image geometry derived from ``Settings``.
        """
        self.head_selection = head_selection
        # Labels for up to ten attention heads; the colour order matches the
        # Category10/tab10 palette (presumably the colours used to draw them).
        self.color_names = ['1 Blue', '2 Orange', '3 Green', '4 Red', '5 Purple', '6 Brown', '7 Pink', '8 Gray', '9 Olive', '10 Cyan']
        self.active_heads = []
        self.args = Settings()
        self.field_height = self.args.field_height
        self.field_width = self.args.field_width
        # NOTE(review): the scale is hard-coded to 1 (the Settings value is
        # ignored), so image dimensions below equal the field dimensions.
        self.pixel_per_meters = 1#self.args.pixel_per_meters
        self.source_object = ColumnDataSource(data={'x': [], 'y': []})
        self.object_glyph = Circle(x='x', y='y', radius=0.5, fill_alpha=0.8, fill_color='red')
        self.source_vehicle = ColumnDataSource(data={'x': [], 'y': [], 'angle': [], 'width': [], 'height': []})

        # Selecting a vehicle in the plot changes 'indices' and calls _tap_on_veh.
        self.source_vehicle.selected.on_change('indices', self._tap_on_veh)
        self.vehicle_glyph = Rect(x='x', y='y', angle='angle', width='width', height='height',
                                  fill_color='red', fill_alpha=0.8)
        self.image = None
        # NOTE(review): get_image() runs before image_height/image_width/
        # image_center are assigned; it presumably does not read them —
        # confirm before reordering anything in this method.
        self.get_image()
        self.image_height = int(self.pixel_per_meters*self.field_height)
        self.image_width = int(self.pixel_per_meters*self.field_width)

        self.image_center = np.array([self.image_width//2, self.image_height//2])
        # Covariance ellipses: list of sources, shared glyph template, count.
        self.source_ellipse = []
        self.ellipse_glyph = Ellipse(x='x', y='y', width='width', height='height', angle='angle',
                                     fill_alpha='alpha', line_color=None, fill_color='blue')
        self.n_ellipses = 0

        # Lane lines (dashed gray): list of sources, shared glyph, count.
        self.source_lane = []
        self.lane_glyph = Line(x='x', y='y', line_color='gray', line_dash='dashed', line_width=3)
        self.n_lanes = 0

        # Triangle markers (arrows); source_arrow starts as two empty lists.
        self.source_arrow = [[], []]
        self.arrow_glyph = Scatter(x='x', y='y', angle='angle', size='size', marker='triangle', line_color=None)
        self.arrow_glyph_list = []

        # Per-path data sources for true future, prediction and history.
        self.source_path_fut = []
        self.source_path_pred = []
        self.source_path_past = []

        # Path colours: future green, prediction red, history gray.
        self.fut_path_glyph = Line(x='x', y='y', line_color='green', line_width=2)
        self.pred_path_glyph = Line(x='x', y='y', line_color='red', line_width=2)
        self.past_path_glyph = Line(x='x', y='y', line_color='gray', line_width=2)

        self.n_past = 0
        self.n_fut = 0
        self.n_pred = 0
        self.attention_matrix = None
Beispiel #4
0
 def __init__(self):
     """Set up the page and draw the first road scene.

     Sets the page title with ``st.title`` and loads ``Settings``. It then calls
     the ``_get_net``, ``_get_dataset`` and ``_get_filter`` helpers, which
     presumably load the model, data and filter into ``self``. Next it picks a
     sample index via ``_select_data`` and creates a ``ScenePlotter``. Finally
     it sets every ``draw_*`` flag to True, calls ``_set_what_to_draw`` and
     renders the scene with ``_draw_image``.
     """
     st.title('Plot sample trajectories')
     self.args = Settings()
     # Order matters: the helpers below presumably depend on self.args and on
     # each other's results.
     self._get_net()
     self._get_dataset()
     self._get_filter()
     self.index = self._select_data()
     self.scene_plotter = ScenePlotter()
     # Default: draw every layer. These are set before _set_what_to_draw(),
     # which presumably lets the user override them.
     self.draw_lanes = True
     self.draw_past = True
     self.draw_fut = True
     self.draw_pred = True
     self.draw_cov = True
     self._set_what_to_draw()
     self._draw_image(
         "Road scene",
         "This represents the road scene, input past observation and forecasting"
     )
Beispiel #5
0
def get_multi_object_test_set():
    """Build the multi-object test dataset selected by ``Settings().dataset``.

    * ``'NGSIM'``     -> ``TestSet_traj_v2.mat`` / ``TestSet_tracks_v2.mat`` in
      ``args.NGSIM_test_data_directory``.
    * ``'Fusion'``    -> ``test_sequenced_data.tar`` in
      ``args.fusion_data_directory``.
    * ``'Argoverse'`` -> the ``val/dataset2`` folder of
      ``args.argoverse_data_directory`` (the validation split is used as the
      test set).

    Returns:
        The test dataset.

    Raises:
        RuntimeError: if ``args.dataset`` is none of the names above.
    """
    args = Settings()
    if args.dataset == 'NGSIM':
        testSet = MultiObjectNGSIMDataset(
            args.NGSIM_test_data_directory + 'TestSet_traj_v2.mat',
            args.NGSIM_test_data_directory + 'TestSet_tracks_v2.mat', args)
    elif args.dataset == 'Fusion':
        testSet = MultiObjectFusionDataset(
            args.fusion_data_directory + 'test_sequenced_data.tar', args)
    elif args.dataset == 'Argoverse':
        # NOTE(review): ``args`` is not passed here; the meaning of the three
        # positional flags is defined by MultiObjectArgoverseDataset. The
        # leading '/' in the sub-path may double a trailing slash of the
        # directory setting — confirm both are intended.
        testSet = MultiObjectArgoverseDataset(
            args.argoverse_data_directory + '/val/dataset2/', False, False,
            True)
    else:
        # The previous message omitted Argoverse, which is supported above.
        raise RuntimeError(
            'Multi object loader does not support other datasets than NGSIM, '
            'Fusion and Argoverse (got {!r}).'.format(args.dataset)
        )
    return testSet
import pickle
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerPatch
from matplotlib.patches import Ellipse
from matplotlib.lines import Line2D
from matplotlib_scalebar.scalebar import ScaleBar

from utils.utils import Settings

def latex_string(data):
    """Format numbers as consecutive LaTeX table cells.

    Each value is rendered as ``"& {:2.2f} "``: the column separator ``&``,
    the value with two decimals, and a trailing space. The cells are
    concatenated in input order.

    Args:
        data: iterable of numbers.

    Returns:
        str: e.g. ``"& 1.00 & 2.50 "`` for ``[1, 2.5]``; ``""`` for empty input.
    """
    # str.join builds the result in one pass instead of repeated ``+=``.
    return ''.join("& {:2.2f} ".format(value) for value in data)

args = Settings()
# Indices of the data coordinates mapped to the plot's x and y axes
# (presumably; their use is outside this view). An earlier NGSIM-specific
# branch used the same values, so every dataset shares one mapping.
x_axis, y_axis = 1, 0

try:
    results = np.load('./results/' + args.load_name + '.npz')
    is_pickle = False
    print('loaded %s' % args.load_name)
except FileNotFoundError as err:
    try:
        results = pickle.load(open('./results/' + args.load_name + '.pickle', 'rb'))
        is_pickle = True
Beispiel #7
0
    def __init__(self):
        """Build the Bokeh control panel and scene view and add them to the document.

        Creates the widgets stored in ``self.controls``:

        * model-family radio buttons
        * model sub-type select
        * parameter-file select
        * dataset select
        * draw-option checkboxes
        * prediction-count slider
        * sequence-ID input
        * attention-head multi-select
        * refresh button

        It wires their callbacks and runs the model-type update once to populate
        the selects. It then lays the controls out next to
        ``self.scene_plotter.get_image()`` and adds the layout to ``curdoc()``.
        """

        ### Instance state
        self.args = Settings()
        self.index = None
        self.data_getter = None
        self.filter = None
        self._data = None
        self._model_type = None
        # Directory searched for parameter files; update_model_type() rewrites
        # it from the selected model family and sub-type.
        self._model_dir = self.args.models_path + 'unique_object/'
        self.controls = {}
        self.scene_plotter = ScenePlotter(self._set_head_selection)

        ### initialization of the interface

        ## Model type selector

        def update_select_net():
            # Fill the parameter-file select with the regular files found in
            # self._model_dir, newest first (by mtime), extensions stripped, and
            # select the newest one (None when the directory has no files).
            if self._model_dir is not None:
                file_list = [fn for fn in os.listdir(self._model_dir) if os.path.isfile(os.path.join(self._model_dir, fn))]
                file_list.sort(key=lambda fn: os.stat(os.path.join(self._model_dir, fn)).st_mtime, reverse=True)
                file_list = [os.path.splitext(fn)[0] for fn in file_list]
                self.controls['net'].options = file_list
                if len(file_list) > 0:
                    self.controls['net'].value = file_list[0]
                else:
                    self.controls['net'].value = None

        def update_model_type():
            # Map the radio-button index to a model family and its base directory.
            if self.controls['multi_mono_object'].active == 0:
                self._model_type = 'mono'
                self._model_dir = self.args.models_path + 'unique_object/'
            elif self.controls['multi_mono_object'].active == 1:
                self._model_type = 'multi_obj'
                self._model_dir = self.args.models_path + 'multi_objects/'
            elif self.controls['multi_mono_object'].active == 2:
                self._model_type = 'multi_pred'
                self._model_dir = self.args.models_path + 'multi_pred/'
            # Offer only the sub-types that have a directory under _model_dir.
            model_types = ['CV', 'CA', 'Bicycle', 'CV_LSTM', 'CA_LSTM', 'Bicycle_LSTM',  'nn_attention']
            existing_types = [type for type in model_types if os.path.isdir(self._model_dir + type)]
            self.controls['model_sub_type'].options = existing_types
            print('existing types')
            print(existing_types)
            # If the current sub-type is unavailable, switch to the first
            # available one and return early. This relies on the on_change
            # callback registered on model_sub_type (which calls this
            # function) to finish the update with the new value.
            if len(existing_types) > 0 and not self.controls['model_sub_type'].value in existing_types:
                self.controls['model_sub_type'].value = existing_types[0]
                return
            # NOTE(review): when existing_types is empty, the stale sub-type
            # value is still appended to _model_dir. update_select_net() then
            # lists a directory that os.path.isdir() just reported missing, so
            # os.listdir() presumably raises FileNotFoundError — confirm.
            set_model_sub_type()
            update_select_net()

        def set_model_sub_type():
            # Descend into the selected sub-type folder and mirror the choice
            # into args.model_type; with no selection, disable file listing
            # (update_select_net() skips a None directory).
            if self.controls['model_sub_type'].value is not None:
                self._model_dir = self._model_dir + self.controls['model_sub_type'].value + '/'
                self.args.model_type = self.controls['model_sub_type'].value
            else:
                self._model_dir = None

        def update_multi_mono_object(attr, old, new):
            # Bokeh on_change signature; the arguments are unused.
            update_model_type()
            print(self._model_dir)
            # NOTE(review): update_select_net() also changes the 'net' value,
            # whose callback calls _set_data_getter(), so this may run twice.
            self._set_data_getter()
            print('___')

        # Index 1 ("Multi-objects") is selected initially.
        multi_mono_object = RadioButtonGroup(labels=["Mono-object", "Multi-objects", "Multi-pred"], active=1)
        self.controls['multi_mono_object'] = multi_mono_object
        multi_mono_object.on_change('active', update_multi_mono_object)

        ## Model sub type selector
        model_types = ['CV', 'CA', 'Bicycle', 'CV_LSTM', 'CA_LSTM', 'Bicycle_LSTM',  'nn_attention']
        # Initial value is 'CV_LSTM'; options are narrowed by update_model_type().
        model_sub_type = Select(title='Select model type:', value=model_types[3], options=model_types)
        self.controls['model_sub_type'] = model_sub_type
        model_sub_type.on_change('value', lambda att, old, new: update_model_type())

        ## Model selector (parameter file); options filled by update_select_net()
        select = Select(title="Select parameter file:", value=None, options=[])
        self.controls['net'] = select
        select.on_change('value', lambda att, old, new: self._set_data_getter())

        ## Select dataset to use
        dataset_list = ['Argoverse', 'Fusion', 'NGSIM']
        select = Select(title='Dataset:', value=dataset_list[0], options=dataset_list)
        self.controls['dataset'] = select
        select.on_change('value', lambda att, old, new: self._set_data_getter(change_index=True))

        ## Set what to draw; all five options start enabled
        checkbox_group = CheckboxGroup(
            labels=['Draw lanes', 'Draw history', 'Draw true future', 'Draw forecast', 'Draw forecast covariance'],
            active=[0, 1, 2, 3, 4])
        self.controls['check_box'] = checkbox_group
        checkbox_group.on_change('active',
                                 lambda att, old, new: (self._update_cov(), self._update_lanes(), self._update_path()))

        ## Set the number of pred (1 to 6 hypotheses)
        n_pred = Slider(start=1, end=6, step=1, value=1, title='Number of prediction hypothesis')
        self.controls['n_pred'] = n_pred
        n_pred.on_change('value', lambda att, old, new: (self._update_cov(), self._update_path()))

        ## Sequence ID input (no callback; presumably read when data is set)
        text_input = TextInput(title="Sequence ID to plot:", value="Random")
        self.controls['sequence_id'] = text_input

        ## Head selection input
        multi_select_head = MultiSelect(title='Attention head multiple selection:',
                                             value=[], options=[])
        self.controls['Head_selection'] = multi_select_head
        multi_select_head.on_change('value', self.scene_plotter.set_active_heads)

        ## Refresh all sample: pick a new index, then reload the data
        button = Button(label="Refresh", button_type="success")
        self.controls['refresh'] = button
        button.on_click(
            lambda event: (self._set_index(), self._set_data()))
        # button.js_on_click(CustomJS(args=dict(p=self.image), code="""p.reset.emit()"""))

        # Populate the selects once now that every control exists.
        update_multi_mono_object(None, None, None)

        ## Set the interface layout; widgets appear in self.controls insertion order
        inputs = column(*(self.controls.values()), width=320, height=1000)
        inputs.sizing_mode = "fixed"
        lay = layout([[inputs, self.scene_plotter.get_image()]], sizing_mode="scale_both")
        curdoc().add_root(lay)

        # Invoke the vehicle-tap handler as if index 0 had been selected
        # (presumably to focus the first vehicle initially).
        self.scene_plotter._tap_on_veh('selected', [], [0])
Beispiel #8
0
 def __init__(self):
     self.args = Settings()
     self.field_height = self.args.field_height
     self.field_width = self.args.field_width
            print('nll', self.nll_test[object_class, indices])
            print('miss rate', self.miss_rate[object_class, indices])
            print('         ==========================')
        else:
            self._translate_object_class(object_class, self.print_stats,
                                         spacing)

    def plot_hist(self, object_class='all', spacing=1):
        """Show a log-binned histogram of distance errors for ``object_class``.

        One histogram series is drawn per selected prediction-horizon index
        (from ``_get_indices_at_spacing(spacing)``). Each series is labelled
        ``int((i + 1) / 5)``, presumably the horizon in seconds at 5 Hz.
        Class names not in ``self.class_dict`` are handed to
        ``_translate_object_class`` together with this method.

        Args:
            object_class: key of ``self.class_dict``, or a name handled by
                ``_translate_object_class`` (default ``'all'``).
            spacing: forwarded to ``_get_indices_at_spacing``.
        """
        self._compute_stats()
        if object_class in self.class_dict:
            indices = self._get_indices_at_spacing(spacing)
            # 12 logarithmically spaced bin edges from 1e-2 to 1e3.
            logbins = np.logspace(np.log10(1.e-2), np.log10(1000), 12)
            print('indices', indices)

            # Rows of dist_error[object_class] are selected by horizon index.
            # plt.hist treats each *column* of a 2D array as one dataset, so
            # transpose to get one series per selected horizon, matching the
            # one label generated per index. Without the transpose, every
            # sample became its own series and the labels were misassigned.
            errors = np.array(self.dist_error[object_class])[indices, :]
            plt.hist(errors.T,
                     bins=logbins,
                     label=[str(int((i + 1) / 5)) for i in indices])
            plt.xscale('log')
            plt.legend()
            plt.show()
        else:
            self._translate_object_class(object_class, self.plot_hist, spacing)


# Script entry: build the statistics from the current settings and show the
# distance-error histogram for the ego vehicle. Other reports are available,
# e.g. stats.print_stats('ego', 1).
stats = StatMultiObject(Settings())
stats.plot_hist(object_class='ego')
Beispiel #10
0
def _load_multi_object_weights(net, args):
    """Load ``args.load_name`` weights into ``net``, retrying non-strictly on mismatch."""
    # 'nn_attention' checkpoints live under multi_pred, all others under
    # multi_objects.
    subdir = 'multi_pred' if args.model_type == 'nn_attention' else 'multi_objects'
    path = ('./trained_models/' + subdir + '/' + args.model_type + '/' +
            args.load_name + '.tar')
    # Read the checkpoint once; the fallback reuses it instead of re-reading.
    state_dict = torch.load(path, map_location=args.device)
    try:
        net.load_state_dict(state_dict)
    except RuntimeError as err:
        print(err)
        print('Loading what can be loaded with option strict=False.')
        net.load_state_dict(state_dict, strict=False)


def get_multi_object_net():
    """Build the multi-object prediction network selected in ``Settings()``.

    ``args.model_type`` selects the architecture:

    * 'CV', 'CA', 'Bicycle' and their '_LSTM' variants are wrapped in
      ``MultiObjectKalman``.
    * 'nn_attention' imports ``AttentionPredictor`` from the ``Argoverse``
      directory next to the current working directory.

    The network is moved to ``args.device``. If ``args.load_name`` is
    non-empty, weights are loaded from
    ``./trained_models/<multi_pred|multi_objects>/<model_type>/<load_name>.tar``.
    A failed strict load is retried with ``strict=False``.

    Returns:
        The network.

    Raises:
        RuntimeError: for GRU model types (not implemented for multi-object
            data) or for an unknown ``args.model_type``.
    """
    args = Settings()

    if args.model_type.endswith('GRU'):
        raise RuntimeError(
            'The action prediction using GRU have not been implemented for multi object data.'
        )
    if args.model_type == 'CV':
        net = MultiObjectKalman(args, CV_model)
    elif args.model_type == 'Bicycle':
        net = MultiObjectKalman(args, Bicycle_model)
    elif args.model_type == 'CA':
        net = MultiObjectKalman(args, CA_model)
    elif args.model_type == 'CV_LSTM':
        net = MultiObjectKalman(args, CV_LSTM_model)
    elif args.model_type == 'CA_LSTM':
        net = MultiObjectKalman(args, CA_LSTM_model)
    elif args.model_type == 'Bicycle_LSTM':
        net = MultiObjectKalman(args, Bicycle_LSTM_model)
    elif args.model_type == 'nn_attention':
        # The attention predictor lives in ../Argoverse relative to the CWD.
        currentdir = os.getcwd()
        module_dir = os.path.join(os.path.dirname(currentdir), 'Argoverse')
        # Put it first on the import path, without adding a duplicate entry
        # on every call.
        if sys.path[:1] != [module_dir]:
            sys.path.insert(0, module_dir)
        from attention_predictor import AttentionPredictor
        net = AttentionPredictor()
    else:
        # Previously this only printed and then crashed on the unbound `net`.
        raise RuntimeError('Model type ' + args.model_type + ' is not known.')

    net = net.to(args.device)

    if args.load_name != '':
        _load_multi_object_weights(net, args)
    return net