# debug_policy_plot()
            # if step % 100000 == 0:
            #         recorder.save(hp.this_run_path+recorder_file)
    # recorder.save(hp.this_run_path+recorder_file)
    print('mean: ', np.mean(reward_and_rewards_list,axis=0))
    print('std dev: ', np.std(reward_and_rewards_list,axis=0))
    print('std error: ', np.std(reward_and_rewards_list,axis=0)/np.sqrt(len(reward_and_rewards_list)))



if __name__ == "__main__":

    recorder = Recorder(n=6)


    sensor = syc.Sensor( fisheye=fy_dict,centralwinx=32,centralwiny=32)
    saccade_agent = syc.Saccadic_Agent()

    reward = syc.Rewards(reward_types=['network'],relative_weights=[100.0])
    # observation_size = sensor.hp.winx*sensor.hp.winy*2
    saccade_observation_size = 64*64+hp.drift_state_size
    # saccade_RL = DeepQNetwork(np.prod(saccade_agent.max_q), saccade_observation_size,
    saccade_RL=DeepQNetwork(64*64, saccade_observation_size,
                      n_features_shaped=list(np.shape(sensor.dvs_view))+[1],
                      shape_fun= None,
                      reward_decay=0.99,
                      replace_target_iter=10,
                      memory_size=100000,
                      e_greedy_increment=0.0001,
                      learning_rate=0.0025,
                      double_q=True,
    #         images[ii]=-image+np.max(image)
    # images = prep_mnist_sparse_images(400,images_per_scene=20)
    images = read_images_from_path(
        '/home/bnapp/arivkindNet/video_datasets/stills_from_videos/some100img_from20bn/*',
        max_image=200)
    # images = [images[1]]
    # images = [np.sum(1.0*uu, axis=2) for uu in images]
    # images = [cv2.resize(uu, dsize=(256, 256-64), interpolation=cv2.INTER_AREA) for uu in images]
    if hp.logmode:
        images = [np.log10(uu + 1.0) for uu in images]

    # with open('../video_datasets/liron_images/shuffled_images.pkl', 'rb') as f:
    #     images = pickle.load(f)

    scene = syc.Scene(frame_list=images)
    sensor = syc.Sensor(log_mode=False, log_floor=1.0)
    agent = syc.Agent(
        max_q=[scene.maxx - sensor.hp.winx, scene.maxy - sensor.hp.winy])

    reward = syc.Rewards(
        reward_types=['central_rms_intensity', 'speed', 'saccade'],
        relative_weights=[1.0, -float(sys.argv[1]), 0])
    # observation_size = sensor.hp.winx*sensor.hp.winy*2
    observation_size = 64 * 64 + 2
    RL = DeepQNetwork(len(agent.hp.action_space),
                      observation_size,
                      n_features_shaped=list(np.shape(sensor.dvs_view)) + [1],
                      shape_fun=None,
                      reward_decay=0.99,
                      e_greedy=0.95,
                      e_greedy0=0.8,
    if add_seed:
        np.random.seed(random.randint(0,add_seed))    
    orig_img = img*1
    #Set the padded image
    img=misc.build_cifar_padded(1./256*img)
    img_size = img.shape
    if img_num == 42:
        print('Are we Random?? ', np.random.randint(1,20))
    if show_fig:
        if count < 5:
            ax[0,i].imshow(orig_img) 
            plt.title(labels[count])
    #Set the sensor and the agent
    scene = syc.Scene(image_matrix=img)
    if up_sample:
        sensor = syc.Sensor(winx=52,winy=52,centralwinx=32,centralwiny=32,nchannels = 3,resolution_fun = lambda x: bad_res_func(x,(res,res)), resolution_fun_type = 'down')
    else:
        sensor = syc.Sensor(winx=32,winy=32,centralwinx=res//2,centralwiny=res//2,nchannels = 3,resolution_fun = lambda x: bad_res102(x,(res,res)), resolution_fun_type = 'down')
    agent = syc.Agent(max_q = [scene.maxx-sensor.hp.winx,scene.maxy-sensor.hp.winy])
    #Setting the coordinates to visit
    if type(trajectory_list) is int:
        if trajectory_list:
            np.random.seed(trajectory_list)
        starting_point = np.array([agent.max_q[0]//2,agent.max_q[1]//2])
        steps  = []
        for j in range(sample):
            steps.append(starting_point*1)
            starting_point += np.random.randint(-2,3,2) 

        if mixed_state:
            q_sequence = np.array(steps).astype(int)
            if step % 10000 == 0:
                # recorder.plot()
                saccade_RL.dqn.save_nwk_param(hp.this_run_path +
                                              'tempX_saccade.nwk')
                drift_net.save_nwk_param(hp.this_run_path + 'tempX_drift.nwk')
                # debug_policy_plot()
            if step % 100000 == 0:
                recorder.save(hp.this_run_path + recorder_file)
    # recorder.save(hp.this_run_path+recorder_file)


if __name__ == "__main__":

    recorder = Recorder(n=6)

    sensor = syc.Sensor(fisheye=fy_dict)
    saccade_agent = syc.Saccadic_Agent()

    reward = syc.Rewards(reward_types=['network'], relative_weights=[100.0])
    # observation_size = sensor.hp.winx*sensor.hp.winy*2
    saccade_observation_size = 64 * 64 + 100
    # saccade_RL = DeepQNetwork(np.prod(saccade_agent.max_q), saccade_observation_size,
    saccade_RL = DeepQNetwork(
        64 * 64,
        saccade_observation_size,
        n_features_shaped=list(np.shape(sensor.dvs_view)) + [1],
        shape_fun=None,
        reward_decay=0.99,
        replace_target_iter=10,
        memory_size=100000,
        e_greedy_increment=0.0001,
    # with open('../video_datasets/liron_images/shuffled_images.pkl', 'rb') as f:
    #     images = pickle.load(f)

    ##
    # mnist = MNIST('/home/bnapp/datasets/mnist/')
    # images, labels = mnist.load_training()
    (images, labels), (images_test,labels_test) = keras.datasets.mnist.load_data(path="mnist.npz")
    if hp.test_mode or hp.eval_mode:
        (images, labels) = (images_test, labels_test)
    img = build_mnist_padded(1. / 256 * np.reshape(images[0], [1, 28, 28])) #just to initialize scene with correct size

    scene = syc.Scene(image_matrix=img)
    sensor = syc.Sensor(winx=56, winy=56,
                        centralwinx=hp.resolution//2,
                        centralwiny=hp.resolution//2,
                        resolution_fun=lambda x: bad_res102(x, (hp.resolution, hp.resolution)),
                        resolution_fun_type='down')
    # sensor.hp.resolution_fun = lambda x: bad_res101(x, (hp.resolution, hp.resolution))
    agent = syc.Agent(max_q=[scene.maxx - sensor.hp.winx, scene.maxy - sensor.hp.winy])

    reward = syc.Rewards(reward_types=['central_rms_intensity', 'speed','manual_reward'],relative_weights=[hp.intensity_reward,hp.speed_reward,hp.loss_reward])
    # observation_size = sensor.hp.winx*sensor.hp.winy*2
    observation_size = 530#2*(hp.resolution//2)**2+2

    rising_beta_schedule = [[hp.beta_t1 // hp.steps_between_learnings, hp.beta_b1], [hp.beta_t2 // hp.steps_between_learnings, hp.beta_b2]]
    flat_beta_schedule = [[hp.beta_t1 // hp.steps_between_learnings, hp.beta_b2], [hp.beta_t2 // hp.steps_between_learnings, hp.beta_b2]]

    # rising_beta_schedule = [[400000 // hp.steps_between_learnings, 0.1], [700000 // hp.steps_between_learnings, 1]]
    # flat_beta_schedule = [[400000 // hp.steps_between_learnings, 1.0], [700000 // hp.steps_between_learnings, 1]]
Exemple #6
0
                #     print('--------------------------------')
                #     old_policy_map = policy_map
            if step%10000 ==0:
                    recorder.plot()
                    RL.dqn.save_nwk_param('temp3.nwk')

if __name__ == "__main__":
    # Script entry point: build a synthetic vertical-edge scene, a sensor/agent
    # pair and a DeepQNetwork learner, then attach the scene and sensor
    # hyper-parameter objects to the global `hp`.

    # 28x28 test image: columns 0-13 are 0.0 and columns 14-27 are 1.0,
    # i.e. a single vertical edge in the middle of the image.
    vertical_edge_mat = np.zeros([28,28])
    vertical_edge_mat[:,14:] = 1.0
    # NOTE(review): the meaning of n=6 is not visible here - presumably the
    # number of recorded series; confirm against Recorder.
    recorder = Recorder(n=6)
    # 10x10 matrix with ones on the main diagonal and on the first
    # superdiagonal. It is not used within this visible block; presumably a
    # debugging aid - confirm before relying on it.
    debu2el = np.diag(np.ones([10-1]),k=1)+np.eye(10)
    # debu2el = debu2el[:-1,:]

    scene = syc.Scene(image_matrix=vertical_edge_mat)
    sensor = syc.Sensor()
    # Per-axis limit for the agent position: scene extent minus sensor window.
    agent = syc.Agent(max_q = [scene.maxx-sensor.hp.winx,scene.maxy-sensor.hp.winy])
    reward = syc.Rewards()
    # NOTE(review): the two positional arguments are presumably the number of
    # actions and the observation size (sensor.hp.winx + 2 here; the
    # commented-out alternative used sensor.frame_size + 2) - confirm against
    # the DeepQNetwork signature.
    RL = DeepQNetwork(len(agent.hp.action_space), sensor.hp.winx+2,#sensor.frame_size+2,
                      reward_decay=0.9,
                      e_greedy=0.99,
                      e_greedy0=0.25,
                      replace_target_iter=10,
                      memory_size=30000,
                      e_greedy_increment=0.001,
                      state_table=None
                      )


    # Attach the scene and sensor hyper-parameter objects to the global hp.
    hp.scene = scene.hp
    hp.sensor = sensor.hp
Exemple #7
0
    # with open('../video_datasets/liron_images/shuffled_images.pkl', 'rb') as f:
    #     images = pickle.load(f)

    ##
    # mnist = MNIST('/home/bnapp/datasets/mnist/')
    # images, labels = mnist.load_training()
    (images,
     labels), (images_test,
               labels_test) = keras.datasets.mnist.load_data(path="mnist.npz")
    if hp.test_mode or hp.eval_mode:
        (images, labels) = (images_test, labels_test)
    img = build_mnist_padded(1. / 256 * np.reshape(
        images[0], [1, 28, 28]))  #just to initialize scene with correct size

    scene = syc.Scene(image_matrix=img)
    sensor = syc.Sensor(winx=56, winy=56, centralwinx=28, centralwiny=28)
    sensor.hp.resolution_fun = lambda x: bad_res101(x, (hp.resolution, hp.
                                                        resolution))
    agent = syc.Agent(
        max_q=[scene.maxx - sensor.hp.winx, scene.maxy - sensor.hp.winy])

    reward = syc.Rewards(
        reward_types=['central_rms_intensity', 'speed', 'manual_reward'],
        relative_weights=[
            hp.intensity_reward, hp.speed_reward, hp.loss_reward
        ])
    # observation_size = sensor.hp.winx*sensor.hp.winy*2
    observation_size = 2080

    rising_beta_schedule = [[
        hp.beta_t1 // hp.steps_between_learnings, hp.beta_b1
def create_dataset(images, labels, res, sample=5, mixed_state=True):
    '''
    Creates torch dataloaders of syclop DVS outputs from a list of images and labels.

    Each image is scaled by 1/256 and padded with ``misc.build_mnist_padded``,
    then scanned by a ``syc.Agent`` along a random-walk trajectory of ``sample``
    points that starts at the centre of the agent's allowed range. The sensor's
    central DVS views recorded along the trajectory form one dataset sample.
    A 2x5 figure with the first five raw images (top row) and their first
    syclop frame (bottom row) is drawn as a side effect.

    Parameters
    ----------
    images : list of images to process (each reshaped to 28x28)
    labels : list of labels, aligned with ``images``
    res : resolution downsampling size, passed to ``bad_res101`` as ``(res, res)``
    sample : number of trajectory points, and therefore syclop time steps, per image
    mixed_state : if False, use the trajectory drawn for the first image on every image.

    Returns
    -------
    train_dataloader, test_dataloader - torch DataLoader objects (batch size 64,
    shuffled) over the DVS views; the first 55000 images form the training
    split and the remaining images the test split.
    '''
    n_train = 55_000  # index separating the training split from the test split
    dvs_images = []
    count = 0
    q_sequence = None
    # Example grid: row 0 holds raw images, row 1 the first syclop frame.
    _fig, ax = plt.subplots(2, 5)
    i = 0  # column index into the example grid
    for img in images:
        orig_img = np.reshape(img, [28, 28])
        # Scale by 1/256 and pad the image for the scene.
        img = misc.build_mnist_padded(1. / 256 * np.reshape(img, [1, 28, 28]))

        if count < 5:
            ax[0, i].imshow(orig_img)
            # plt.title would label the *current* axes (the last one created by
            # plt.subplots), so title the intended subplot explicitly.
            ax[0, i].set_title(labels[count])
        # Set the sensor and the agent
        scene = syc.Scene(image_matrix=img)
        sensor = syc.Sensor(winx=56, winy=56, centralwinx=28, centralwiny=28)
        agent = syc.Agent(
            max_q=[scene.maxx - sensor.hp.winx, scene.maxy - sensor.hp.winy])

        # Random walk of `sample` points starting at the centre of max_q.
        starting_point = np.array([agent.max_q[0] // 2, agent.max_q[1] // 2])
        steps = []
        for _ in range(sample):
            steps.append(starting_point * 1)  # copy: starting_point is updated in place
            # randint's upper bound is exclusive; use 6 so steps are symmetric
            # in [-5, 5] (an upper bound of 5 biased the walk towards negatives).
            starting_point += np.random.randint(-5, 6, 2)

        # Without mixed_state, only the first image's trajectory is kept.
        if mixed_state or count == 0:
            q_sequence = np.array(steps).astype(int)

        # Resolution function applied by the sensor.
        sensor.hp.resolution_fun = lambda x: bad_res101(x, (res, res))
        frames = []
        dvs_frames = []
        agent.set_manual_trajectory(manual_q_sequence=q_sequence)
        # One syclop time step per trajectory point.
        for _ in range(len(q_sequence)):
            agent.manual_act()
            sensor.update(scene, agent)
            frames.append(sensor.central_frame_view)
            dvs_frames.append(sensor.central_dvs_view)

        if count < 5:
            ax[1, i].imshow(frames[0])
            ax[1, i].set_title(labels[count])
            i += 1

        dvs_images.append(np.array(dvs_frames))
        count += 1

    dvs_train = dvs_images[:n_train]
    train_labels = labels[:n_train]
    dvs_val = dvs_images[n_train:]
    val_labels = labels[n_train:]

    class mnist_dataset(Dataset):
        '''Map-style dataset pairing each syclop sample with its label.

        Labels are exposed through the ``labels`` attribute. A ``labels()``
        method would be shadowed by that attribute, so none is defined.
        '''

        def __init__(self, data, labels, transform=None):
            self.data = data
            self.labels = labels
            self.transform = transform

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            '''Return ``(data, label)`` for ``idx``, applying ``transform`` if set.'''
            data = self.data[idx]
            if self.transform:
                data = self.transform(data)
            return data, self.labels[idx]

        def dataset(self):
            return self.data

    train_dataset = mnist_dataset(dvs_train, train_labels)
    test_dataset = mnist_dataset(dvs_val, val_labels)
    batch = 64
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch,
                                                   shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch,
                                                  shuffle=True)

    return train_dataloader, test_dataloader
    #finish with updates
    if hp.fisheye_file is None:
        fy_dict = None
    else:
        with open(hp.fisheye_file, 'rb') as f:
            fy_dict = pickle.load(f)
    hp.this_run_name = sys.argv[
        0] + '_noname_' + hp.run_name_suffix + '_' + str(int(
            time.time())) + '_' + lsbjob
    hp.grayscale = hp.color == 'grayscale'
    nchannels = 1 if hp.grayscale else 3

    recorder = Recorder(n=6)

    sensor = syc.Sensor(fisheye=fy_dict,
                        centralwinx=32,
                        centralwiny=32,
                        nchannels=nchannels)

    saccade_agent = syc.Saccadic_Agent()

    reward = syc.Rewards(reward_types=['network'], relative_weights=[100.0])
    # observation_size = sensor.hp.winx*sensor.hp.winy*2
    if hp.color == 'grayscale':
        visual_features = 64 * 64
    elif hp.color == 'rgb':
        visual_features = 64 * 64 * 3
    saccade_observation_size = visual_features + hp.drift_state_size

    n_features_shaped = list(np.shape(sensor.dvs_view))
    if len(n_features_shaped) < 3:  # to support 2 and 3d dvs views
        n_features_shaped.append(nchannels)
def create_mnist_dataset(images,
                         labels,
                         res,
                         sample=5,
                         mixed_state=True,
                         add_traject=True,
                         q_0=None,
                         alpha=0,
                         trajectory_list=None,
                         random_trajectories=False,
                         return_datasets=False,
                         add_seed=70000,
                         show_fig=False,
                         mix_res=False,
                         bad_res_func=None,
                         up_sample=False,
                         acceleration_mode=False):
    #mix_res = False, bad_res_func = bad_res102, up_sample = False):
    '''
    Creates torch datasets/dataloaders of syclop outputs from a list of images and labels.

    Each image is scaled by 1/256 and padded with ``misc.build_mnist_padded``,
    then scanned by a ``syc.Agent`` along a trajectory. The sensor's full
    ``frame_view`` along the trajectory forms one sample.

    Parameters
    ----------
    images : list of images to process (each reshaped to 28x28)
    labels : list of labels, aligned with ``images``
    res : resolution downsampling size, passed to ``bad_res_func`` as ``(res, res)``
    sample : number of points in each generated random-walk trajectory
    mixed_state : if False, use the same trajectory on every image.
    add_traject : if True, dataset items are ``(frames, trajectory, label)``;
                  otherwise ``(frames, label)``.
    q_0 : optional reference trajectory. When given, every image uses
          ``(q_0 * (1 - alpha) + trajectory * alpha).astype(int)`` where
          ``trajectory`` is the un-blended generated trajectory.
    alpha : blending weight used with ``q_0``.
    trajectory_list : uses a preset trajectory from the list (indexed by image number).
    random_trajectories : if True, draw a new trajectory for every image; otherwise
                          the trajectory drawn for the first image is reused.
    return_datasets : return datasets rather than dataloaders.
    add_seed : if truthy, numpy is re-seeded before every image with
               ``random.randint(0, add_seed)``. This limits the number of
               distinct random trajectories to ``add_seed + 1``.
    show_fig : whether to plot a few examples of the dataset, default = False
    mix_res  : if True, each training image (index < 55000) gets a random
               resolution in [6, 10], while validation images keep ``res``.
               Used to test whether a network can learn from mixed resolutions.
    bad_res_func : function that creates the low-resolution images; it is
                   called as ``bad_res_func(x, (res, res))``.
    up_sample    : whether bad_res_func up-samples. If True the sensor uses a
                   28x28 central window, otherwise ``res // 2``.
    acceleration_mode : if True, the random walk perturbs the velocity (steps
                        in [-1, 1]) instead of the position (steps in [-5, 5]).

    Returns
    -------
    ``(train_dataset, test_dataset)`` if ``return_datasets``, otherwise
    ``(train_dataloader, test_dataloader)`` - torch DataLoader objects with
    batch size 64 and shuffling. The first 55000 images form the training
    split and the remaining images the validation split.
    '''
    n_train = 55000  # index separating the training split from the validation split
    ts_images = []
    q_seq = []
    count = 0
    res_orig = res * 1
    # Un-blended trajectory. It is kept separately so that the q_0/alpha blend
    # is applied once per image instead of compounding on reused trajectories.
    base_sequence = None
    if show_fig:
        # create subplot to hold examples from the dataset
        _fig, ax = plt.subplots(2, 5)
        i = 0  # column index into the example grid
    for img_num, img in enumerate(images):

        if add_seed:
            np.random.seed(random.randint(0, add_seed))

        if mix_res:
            # Draw first (even for validation images) to keep the RNG stream unchanged.
            res = random.randint(6, 10)
            if img_num >= n_train:
                res = res_orig
        orig_img = np.reshape(img, [28, 28])
        # Set the padded image
        img = misc.build_mnist_padded(1. / 256 * np.reshape(img, [1, 28, 28]))
        if img_num == 42:
            # Debug print kept deliberately: it draws from np.random, so removing
            # it would change the trajectory generated for this image.
            print('Are we random?', np.random.randint(1, 20))
        if show_fig and count < 5:
            ax[0, i].imshow(orig_img)
            # Title the intended subplot (plt.title would hit the current axes).
            ax[0, i].set_title(labels[count])
        # Set the sensor and the agent
        scene = syc.Scene(image_matrix=img)
        central = 28 if up_sample else res // 2
        sensor = syc.Sensor(
            winx=56,
            winy=56,
            centralwinx=central,
            centralwiny=central,
            resolution_fun=lambda x: bad_res_func(x, (res, res)),
            resolution_fun_type='down')

        agent = syc.Agent(
            max_q=[scene.maxx - sensor.hp.winx, scene.maxy - sensor.hp.winy])
        # Setting the coordinates to visit
        if trajectory_list is None:
            if img_num == 0 or random_trajectories:
                starting_point = np.array(
                    [agent.max_q[0] // 2, agent.max_q[1] // 2])
                steps = []
                qdot = 0
                for _ in range(sample):
                    steps.append(starting_point * 1)  # copy: updated in place below
                    if acceleration_mode:
                        qdot += np.random.randint(-1, 2, 2)
                        starting_point += qdot
                    else:
                        starting_point += np.random.randint(-5, 6, 2)

                # Without mixed_state, only the first image's trajectory is kept.
                if mixed_state or count == 0:
                    base_sequence = np.array(steps).astype(int)
            q_sequence = base_sequence
            if q_0 is not None:
                q_sequence = (q_0 * (1 - alpha) +
                              base_sequence * alpha).astype(int)
        else:
            q_sequence = np.array(trajectory_list[img_num]).astype(int)

        # Record the syclop output along the trajectory
        frames = []
        agent.set_manual_trajectory(manual_q_sequence=q_sequence)
        # One syclop time step per trajectory point.
        for _ in range(len(q_sequence)):
            agent.manual_act()
            sensor.update(scene, agent)
            # NOTE: the full sensor.frame_view is recorded (not central_frame_view).
            frames.append(sensor.frame_view)
        if show_fig and count < 5:
            ax[1, i].imshow(frames[0])
            ax[1, i].set_title(labels[count])
            i += 1

        ts_images.append(np.array(frames))
        q_seq.append(q_sequence)
        count += 1

    # With trajectories, each split becomes a (frames, trajectories) pair.
    # The labels are split the same way in both cases.
    if add_traject:
        ts_train = (ts_images[:n_train], q_seq[:n_train])
        ts_val = (ts_images[n_train:], q_seq[n_train:])
    else:
        ts_train = ts_images[:n_train]
        ts_val = ts_images[n_train:]
    train_labels = labels[:n_train]
    val_labels = labels[n_train:]

    class mnist_dataset():
        '''Map-style dataset over syclop samples.

        With ``add_traject`` True, ``data`` is a ``(frames_list, trajectories)``
        pair and items are ``(frames, trajectory, label)``. Otherwise ``data``
        is the list of frames and items are ``(frames, label)``. Labels are
        exposed through the ``labels`` attribute.
        '''

        def __init__(self, data, labels, add_traject=False, transform=None):
            self.data = data
            self.labels = labels
            self.add_traject = add_traject
            self.transform = transform

        def __len__(self):
            return len(self.data[0]) if self.add_traject else len(self.data)

        def __getitem__(self, idx):
            label = self.labels[idx]
            if self.add_traject:
                img_data = self.data[0][idx]
                if self.transform:
                    img_data = self.transform(img_data)
                return img_data, self.data[1][idx], label
            data = self.data[idx]
            if self.transform:
                data = self.transform(data)
            return data, label

        def dataset(self):
            return self.data

    train_dataset = mnist_dataset(ts_train, train_labels, add_traject=add_traject)
    test_dataset = mnist_dataset(ts_val, val_labels, add_traject=add_traject)

    if return_datasets:
        return train_dataset, test_dataset

    # Imported here so the function needs torch only when dataloaders are requested.
    from torch.utils.data import DataLoader
    batch = 64
    train_dataloader = DataLoader(train_dataset, batch_size=batch, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch, shuffle=True)
    return train_dataloader, test_dataloader
Exemple #11
0
def create_mnist_dataset(images,
                         labels,
                         res,
                         sample=5,
                         mixed_state=True,
                         add_traject=True,
                         trajectory_list=None,
                         return_datasets=False,
                         add_seed=20,
                         show_fig=False):
    '''
    Creates torch datasets/dataloaders of syclop outputs from a list of images and labels.

    Each image is scaled by 1/256 and padded with ``misc.build_mnist_padded``,
    then scanned by a ``syc.Agent`` along a trajectory. The sensor's central
    frame views along the trajectory form one sample.

    Parameters
    ----------
    images : list of images to process (each reshaped to 28x28)
    labels : list of labels, aligned with ``images``
    res : resolution downsampling size, passed to ``bad_res101`` as ``(res, res)``
    sample : number of points in each generated random-walk trajectory
    mixed_state : if False, use the same trajectory on every image.
    add_traject : if True, dataset items are ``(frames, trajectory, label)``;
                  otherwise ``(frames, label)``.
    trajectory_list : uses a preset trajectory from the list (indexed by image number).
    return_datasets : return datasets rather than dataloaders.
    add_seed : if truthy, numpy is re-seeded before every image with a seed
               drawn from ``torch.randint(1, add_seed)`` (upper bound
               exclusive). This limits the number of distinct random
               trajectories to ``add_seed - 1`` (default 20 -> 19).
    show_fig : whether to plot a few examples of the dataset, default = False

    Returns
    -------
    ``(train_dataset, test_dataset)`` if ``return_datasets``. Otherwise
    ``(train_dataloader, test_dataloader, ts_train, train_labels, q_sequence)``:
    the dataloaders use batch size 64 with shuffling, and ``q_sequence`` is the
    trajectory used for the last image (None when ``images`` is empty).
    The first 55000 images form the training split, the rest the validation split.
    '''
    n_train = 55000  # index separating the training split from the validation split
    ts_images = []
    q_seq = []
    count = 0
    q_sequence = None

    if show_fig:
        #create subplot to hold examples from the dataset
        _fig, ax = plt.subplots(2, 5)
        i = 0  # column index into the example grid
    for img_num, img in enumerate(images):
        if add_seed:
            np.random.seed(torch.randint(1, add_seed, (1, )))

        orig_img = np.reshape(img, [28, 28])
        #Set the padded image
        img = misc.build_mnist_padded(1. / 256 * np.reshape(img, [1, 28, 28]))

        if show_fig and count < 5:
            ax[0, i].imshow(orig_img)
            # Title the intended subplot (plt.title would hit the current axes).
            ax[0, i].set_title(labels[count])
        #Set the sensor and the agent
        scene = syc.Scene(image_matrix=img)
        sensor = syc.Sensor(winx=56, winy=56, centralwinx=28, centralwiny=28)
        agent = syc.Agent(
            max_q=[scene.maxx - sensor.hp.winx, scene.maxy - sensor.hp.winy])
        #Setting the coordinates to visit
        if trajectory_list is None:
            starting_point = np.array(
                [agent.max_q[0] // 2, agent.max_q[1] // 2])
            steps = []
            for _ in range(sample):
                steps.append(starting_point * 1)  # copy: updated in place below
                # randint's upper bound is exclusive; 6 gives symmetric steps in [-5, 5].
                starting_point += np.random.randint(-5, 6, 2)

            # Without mixed_state, only the first image's trajectory is kept.
            if mixed_state or count == 0:
                q_sequence = np.array(steps).astype(int)
        else:
            q_sequence = np.array(trajectory_list[img_num]).astype(int)

        #Setting the resolution function
        sensor.hp.resolution_fun = lambda x: bad_res101(x, (res, res))
        frames = []
        agent.set_manual_trajectory(manual_q_sequence=q_sequence)
        # One syclop time step per trajectory point.
        for _ in range(len(q_sequence)):
            agent.manual_act()
            sensor.update(scene, agent)
            frames.append(sensor.central_frame_view)
        if show_fig and count < 5:
            ax[1, i].imshow(frames[0])
            ax[1, i].set_title(labels[count])
            i += 1

        ts_images.append(np.array(frames))
        q_seq.append(q_sequence)
        count += 1

    # With trajectories, each split becomes a (frames, trajectories) pair.
    # The labels are split the same way in both cases.
    if add_traject:
        ts_train = (ts_images[:n_train], q_seq[:n_train])
        ts_val = (ts_images[n_train:], q_seq[n_train:])
    else:
        ts_train = ts_images[:n_train]
        ts_val = ts_images[n_train:]
    train_labels = labels[:n_train]
    val_labels = labels[n_train:]

    class mnist_dataset():
        '''Map-style dataset over syclop samples.

        With ``add_traject`` True, ``data`` is a ``(frames_list, trajectories)``
        pair and items are ``(frames, trajectory, label)``. Otherwise ``data``
        is the list of frames and items are ``(frames, label)``. Labels are
        exposed through the ``labels`` attribute.
        '''

        def __init__(self, data, labels, add_traject=False, transform=None):
            self.data = data
            self.labels = labels
            self.add_traject = add_traject
            self.transform = transform

        def __len__(self):
            return len(self.data[0]) if self.add_traject else len(self.data)

        def __getitem__(self, idx):
            label = self.labels[idx]
            if self.add_traject:
                img_data = self.data[0][idx]
                if self.transform:
                    img_data = self.transform(img_data)
                return img_data, self.data[1][idx], label
            data = self.data[idx]
            if self.transform:
                data = self.transform(data)
            return data, label

        def dataset(self):
            return self.data

    train_dataset = mnist_dataset(ts_train, train_labels, add_traject=add_traject)
    test_dataset = mnist_dataset(ts_val, val_labels, add_traject=add_traject)

    if return_datasets:
        return train_dataset, test_dataset

    # Build dataloaders only when they are returned. A shuffled DataLoader
    # rejects an empty dataset, so building them eagerly made
    # return_datasets=True fail on empty splits.
    batch = 64
    train_dataloader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch,
                                                   shuffle=True)
    test_dataloader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch,
                                                  shuffle=True)
    return train_dataloader, test_dataloader, ts_train, train_labels, q_sequence