Ejemplo n.º 1
0
class BatchLoader(object):
    """Serve augmented images and their labels one at a time.

    The whole dataset is read into memory at construction time, either from a
    ``.mat`` file or from an LMDB directory.  Images are returned in a random
    order that is reshuffled each time every image has been served once
    (one epoch).
    """

    def __init__(self, params, result):
        """Load the images, labels and mean image described by ``params``.

        Args:
            params: dict with keys ``'source'`` (a ``.mat`` file or an LMDB
                directory), ``'mean_file'`` (``0``, or a ``.mat`` /
                ``.binaryproto`` file) and ``'shape'``, ``'scale'``, ``'rot'``,
                which are forwarded to ``SimpleAugment``.
            result: stored unchanged on ``self.result``.

        Raises:
            ValueError: if ``'source'`` or ``'mean_file'`` has an unsupported
                format.
        """
        self.result = result

        source = params['source']
        if source.endswith('.mat'):
            self.img_data, self.img_labels = mat2py_imdb(source)
            # swap dims to have Num x height x width (move axis 2 to the front)
            self.img_data = np.rollaxis(self.img_data, 2)
        elif os.path.isdir(source):  # assume it is lmdb
            lmdb_ = lmdbpy()
            # use read_encoded instead if the images are encoded
            self.img_data, self.img_labels = lmdb_.read(source)
            self.img_data = self.img_data.squeeze()
        else:
            # Raise rather than assert: `python -O` strips asserts, which would
            # turn this into an AttributeError on self.img_labels further down.
            raise ValueError(
                'Invalid format for source {}: need either lmdb or .mat data'
                .format(source))

        mean_file = params['mean_file']
        if mean_file == 0:  # 0 selects a zero mean instead of a mean file
            self.img_mean = 0
        elif mean_file.endswith('.mat'):
            self.img_mean = mat2py_mean(mean_file)
        elif mean_file.endswith('.binaryproto'):
            self.img_mean = biproto2py(mean_file).squeeze()
        else:
            raise ValueError('Invalid format for mean_file {}'.format(mean_file))

        # A real list (not a Python 3 range object) so it can be shuffled in place.
        self.indexlist = list(range(len(self.img_labels)))
        shuffle(self.indexlist)
        self._cur = 0  # position in self.indexlist of the next image to serve

        # this class does some simple data-manipulations
        self.img_augment = SimpleAugment(mean=self.img_mean,
                                         shape=params['shape'],
                                         scale=params['scale'],
                                         rot=params['rot'])

        print("\nBatchLoader initialized with {} images".format(
            len(self.img_labels)))

    def load_next_image(self):
        """Return ``(augmented_image, label)`` for the next image.

        Once every index in ``self.indexlist`` has been served, the cursor is
        reset and the order reshuffled before continuing (a new epoch).
        """
        if self._cur == len(self.indexlist):
            self._cur = 0
            shuffle(self.indexlist)

        index = self.indexlist[self._cur]
        img = self.img_data[index]
        label = self.img_labels[index]

        self._cur += 1
        return (self.img_augment.augment(img), label)
Ejemplo n.º 2
0
class BatchLoader(object):
    """Serve batches of (sketch, image) pairs with their class labels.

    Sketches and images are read on demand from their stores (LMDB, or a
    ``.pkl`` file for sketches).  Each pair combines the next sketch and the
    next image from two independently shuffled index orders, so the two
    labels of a pair need not match.  ``load_next_batch`` builds a batch on a
    thread pool; a lock serialises the shared cursors.
    """

    def __init__(self, params):
        """Open the data sources and prepare the shuffled index orders.

        Args:
            params: dict with keys ``'batch_size'``, ``'shape'``,
                ``'classes_per_batch'``, ``'img_source'``, ``'skt_source'``,
                ``'mean_file'``, ``'scale'`` and ``'rot'``; optionally one of
                ``'hard_pos'`` / ``'hard_neg'`` / ``'hard_pn'`` (files loaded
                with ``np.load``) and ``'verbose'`` (its presence silences the
                label-conversion message).

        Raises:
            ValueError: if the sketch classes do not match the image classes,
                or the sketch count is not divisible by the number of classes.
        """
        self.batch_size = params['batch_size']
        self.img_shape = params['shape']
        self.classes_per_batch = params['classes_per_batch']

        self.img_lmdb = lmdbs(params['img_source'])
        if params['skt_source'].endswith('.pkl'):
            self.skt_lmdb = svgs(params['skt_source'])
        else:
            self.skt_lmdb = lmdbs(params['skt_source'])
        self.img_labels = self.img_lmdb.get_label_list()
        self.skt_labels = self.skt_lmdb.get_label_list()

        # Remap labels onto 0..NCATS-1 if needed.  Sorting is required: the
        # first/last check below only detects 0..NCATS-1 on a sorted list, and
        # set iteration order is not guaranteed (nor is the resulting mapping).
        label_ids = sorted(set(self.img_labels))
        NCATS = len(label_ids)
        if label_ids[0] != 0 or label_ids[-1] != NCATS - 1:
            if 'verbose' not in params:
                print('Your data labels are not [0:{}]. Converting label ...'.format(
                    NCATS - 1))
            # dict lookup instead of list.index: O(1) per label
            label_map = {label: i for i, label in enumerate(label_ids)}
            self.img_labels = [label_map[label] for label in self.img_labels]
            try:
                self.skt_labels = [label_map[label] for label in self.skt_labels]
            except KeyError:
                # a sketch class that never occurs among the images
                raise ValueError('XX!!Sketch & image datasets unequal #categories')

        self.img_mean = biproto2py(params['mean_file']).squeeze()

        self.num_classes = len(set(self.skt_labels))
        # Explicit raises rather than asserts, which `python -O` would strip.
        if self.num_classes != NCATS:
            raise ValueError('XX!!Sketch & image datasets unequal #categories')
        if len(self.skt_labels) % self.num_classes != 0:
            raise ValueError('Unequal sketch training samples for each class')
        # Floor division keeps an int on Python 3; the check above makes it exact.
        self.skt_per_class = len(self.skt_labels) // self.num_classes

        # NOTE: hard_sel / hard_pos / hard_neg are not read by load_next_pair;
        # images are currently drawn from an independent shuffled order.
        if 'hard_pos' in params:
            self.hard_sel = 1
            self.hard_pos = np.load(params['hard_pos'])['pos']
        elif 'hard_neg' in params:
            self.hard_sel = 2
            self.hard_neg = np.load(params['hard_neg'])['neg']
        elif 'hard_pn' in params:
            self.hard_sel = 3
            tmp = np.load(params['hard_pn'])
            self.hard_pos = tmp['pos']
            self.hard_neg = tmp['neg']
        else:  # hard selection turned off
            self.hard_sel = 0

        # Real lists (not Python 3 range objects) so they can be shuffled in place.
        self.indexlist = list(range(len(self.skt_labels)))
        self.indexlist_img = list(range(len(self.img_labels)))
        shuffle(self.indexlist)
        shuffle(self.indexlist_img)
        self._cur = 0  # next position in self.indexlist (sketches)
        self._cur_img = 0  # next position in self.indexlist_img (images)

        # this class does some simple data-manipulations
        self.img_augment = SimpleAugment(mean=self.img_mean,
                                         shape=self.img_shape,
                                         scale=params['scale'],
                                         rot=params['rot'])

        print("BatchLoader initialized with {} sketches, {} images of {} classes".format(
            len(self.skt_labels), len(self.img_labels), self.num_classes))
        # thread pool for parallel augmentation
        self.pool = ThreadPool()

    def load_next_pair(self, l):
        """Return the next (sketch, image) pair.

        Args:
            l: lock shared by all workers of one batch; guards the cursors.

        Returns:
            dict with keys ``'sketch'`` and ``'image'`` (augmented arrays) and
            ``'label_s'`` / ``'label_i'`` (their labels).
        """
        # Advance both cursors under the lock.  The with-block releases the
        # lock even if reshuffling raises; a leaked lock would deadlock the pool.
        with l:
            if self._cur == len(self.indexlist):  # finished a sketch epoch
                self._cur = 0
                shuffle(self.indexlist)
            if self._cur_img == len(self.indexlist_img):  # finished an image epoch
                self._cur_img = 0
                shuffle(self.indexlist_img)
            index = self.indexlist[self._cur]
            index_img = self.indexlist_img[self._cur_img]
            self._cur += 1
            self._cur_img += 1

        skt = self.skt_lmdb.get_datum(index).squeeze()
        label = self.skt_labels[index]

        img_i = self.img_lmdb.get_datum(index_img).squeeze()
        label_i = self.img_labels[index_img]

        res = dict(sketch=self.img_augment.augment(skt),
                   image=self.img_augment.augment(img_i),
                   label_s=label,
                   label_i=label_i)
        return res

    def load_next_batch(self):
        """Load ``batch_size`` pairs in parallel and stack them into arrays."""
        res = {}
        lock = Lock()
        jobs = [self.pool.apply_async(self.load_next_pair, (lock,))
                for _ in range(self.batch_size)]
        thread_res = [job.get() for job in jobs]
        res['data_s'] = np.asarray([tri['sketch']
                                    for tri in thread_res])[:, None, :, :]
        res['data_i'] = np.asarray([tri['image']
                                    for tri in thread_res])[:, None, :, :]
        res['label_s'] = np.asarray([tri['label_s'] for tri in thread_res],
                                    dtype=np.float32)[:, None]
        res['label_i'] = np.asarray([tri['label_i'] for tri in thread_res],
                                    dtype=np.float32)[:, None]
        return res

    def shuffle_keeping_min_classes_per_batch(self):
        """Reorder ``self.indexlist`` so consecutive runs share a class.

        Sketch indices are grouped by class (random order within each class)
        and the order of the classes is shuffled.  The list is then rebuilt
        chunk by chunk: each chunk takes the next
        ``batch_size // classes_per_batch`` sketches (at least one) from every
        class in turn.  A final partial chunk takes whatever each class has
        left.  Assumes every class has exactly ``skt_per_class`` sketches;
        ``__init__`` only checks divisibility.  Not called by this class at
        present.
        """
        shuffle(self.indexlist)

        # Stable sort by label: grouped by class, random order within a class.
        by_class = sorted(self.indexlist, key=lambda index: self.skt_labels[index])

        # One row per class, then shuffle the order of the rows (classes).
        by_class = np.array(by_class, dtype=np.int64).reshape(
            self.num_classes, self.skt_per_class)
        by_class = np.random.permutation(by_class)

        # max(1, ...) avoids a zero range() step when batch_size < classes_per_batch.
        step = max(1, self.batch_size // self.classes_per_batch)
        # The column slice is clamped at the row end, so a partial last chunk
        # never reads into the next class's row or past the end of the array.
        chunks = [by_class[:, k:k + step].ravel()
                  for k in range(0, self.skt_per_class, step)]
        self.indexlist = np.concatenate(chunks).tolist() if chunks else []
Ejemplo n.º 3
0
class BatchLoader(object):
    """Serve batches of augmented images and labels held in memory.

    The dataset is read into memory at construction time from a ``.mat`` file
    or an LMDB directory.  ``load_next_batch`` assembles a batch by running
    ``load_next_image`` on a thread pool.  A lock serialises the shared epoch
    cursor, so each image is served once before the order is reshuffled.
    """

    def __init__(self, params):
        """Load images, labels and the mean image described by ``params``.

        Args:
            params: dict with keys ``'batch_size'``, ``'shape'``, ``'source'``
                (a ``.mat`` file or an LMDB directory), ``'mean_file'`` (``0``,
                or a ``.mat`` / ``.binaryproto`` file), ``'scale'`` and
                ``'rot'``.

        Raises:
            ValueError: if ``'source'`` or ``'mean_file'`` has an unsupported
                format.
        """
        self.batch_size = params['batch_size']
        self.outshape = params['shape']

        source = params['source']
        if source.endswith('.mat'):
            self.img_data, self.img_labels = mat2py_imdb(source)
            # swap dims to have Num x height x width
            self.img_data = self.img_data.transpose(2, 0, 1)
        elif os.path.isdir(source):  # assume it is lmdb
            lmdb_ = lmdbpy()
            # use read_encoded instead if the images are encoded
            self.img_data, self.img_labels = lmdb_.read(source)
            self.img_data = self.img_data.squeeze()
        else:
            # Raise rather than assert: `python -O` strips asserts.
            raise ValueError(
                'Invalid format for source {}: need either lmdb or .mat data'
                .format(source))

        # Remap labels onto 0..NCATS-1 if needed.  Sorting is required: the
        # first/last check below only detects 0..NCATS-1 on a sorted list, and
        # set iteration order is not guaranteed (nor is the resulting mapping).
        label_ids = sorted(set(self.img_labels))
        NCATS = len(label_ids)
        if label_ids[0] != 0 or label_ids[-1] != NCATS - 1:
            print('Your data labels are not [0:{}]. Converting label ...'.format(
                NCATS - 1))
            # dict lookup instead of list.index: O(1) per label
            label_map = {label: i for i, label in enumerate(label_ids)}
            self.img_labels = [label_map[label] for label in self.img_labels]

        mean_file = params['mean_file']
        if mean_file == 0:  # 0 selects a zero mean instead of a mean file
            self.img_mean = 0
        elif mean_file.endswith('.mat'):
            self.img_mean = mat2py_mean(mean_file)
        elif mean_file.endswith('.binaryproto'):
            self.img_mean = biproto2py(mean_file).squeeze()
        else:
            raise ValueError('Invalid format for mean_file {}'.format(mean_file))

        # A real list (not a Python 3 range object) so it can be shuffled in place.
        self.indexlist = list(range(len(self.img_labels)))
        shuffle(self.indexlist)
        self._cur = 0  # position in self.indexlist of the next image to serve

        # this class does some simple data-manipulations
        self.img_augment = SimpleAugment(mean=self.img_mean,
                                         shape=params['shape'],
                                         scale=params['scale'],
                                         rot=params['rot'])

        print("\nBatchLoader initialized with {} images".format(
            len(self.img_labels)))
        self.pool = ThreadPool()

    def load_next_image(self, l):
        """Return ``(augmented_image, label)`` for the next image.

        Args:
            l: lock shared by all workers of one batch; guards ``self._cur``.
        """
        # The with-block releases the lock even if reshuffling raises; a leaked
        # lock would deadlock every later worker.
        with l:
            if self._cur == len(self.indexlist):  # finished an epoch
                self._cur = 0
                shuffle(self.indexlist)
            index = self.indexlist[self._cur]
            self._cur += 1

        img = self.img_data[index]
        label = self.img_labels[index]
        return (self.img_augment.augment(img), label)

    def load_next_batch(self):
        """Load ``batch_size`` images in parallel and stack them into arrays."""
        res = {}
        lock = Lock()
        jobs = [self.pool.apply_async(self.load_next_image, (lock,))
                for _ in range(self.batch_size)]
        thread_res = [job.get() for job in jobs]
        res['data'] = np.asarray([datum[0]
                                  for datum in thread_res])[:, None, :, :]
        res['label'] = np.asarray([datum[1] for datum in thread_res],
                                  dtype=np.float32)
        return res
Ejemplo n.º 4
0
class BatchLoader(object):
    """Serve batches of (anchor sketch, positive image, negative image) triplets.

    Sketches and images are read on demand from their stores (LMDB, or a
    ``.pkl`` file for sketches).  The sketch order is rebuilt every epoch by
    ``shuffle_keeping_min_classes_per_batch``.  ``load_next_batch`` builds a
    batch on a thread pool; a lock serialises the shared cursor.
    """

    def __init__(self, params):
        """Open the data sources and prepare the sketch order.

        Args:
            params: dict with keys ``'batch_size'``, ``'shape'``,
                ``'classes_per_batch'``, ``'img_source'``, ``'skt_source'``,
                ``'mean_file'``, ``'scale'`` and ``'rot'``; optionally
                ``'verbose'`` (its presence silences the start-up message).

        Raises:
            ValueError: if the sketch count is not divisible by the number of
                sketch classes.
        """
        self.batch_size = params['batch_size']
        self.outshape = params['shape']
        self.classes_per_batch = params['classes_per_batch']

        self.img_lmdb = lmdbs(params['img_source'])
        if params['skt_source'].endswith('.pkl'):
            self.skt_lmdb = svgs(params['skt_source'])
        else:
            self.skt_lmdb = lmdbs(params['skt_source'])
        self.img_labels = self.img_lmdb.get_label_list()
        self.skt_labels = self.skt_lmdb.get_label_list()
        self.img_mean = biproto2py(params['mean_file']).squeeze()

        self.num_classes = len(set(self.skt_labels))
        # Explicit raise rather than assert, which `python -O` would strip.
        if len(self.skt_labels) % self.num_classes != 0:
            raise ValueError('Unequal sketch training samples for each class')
        # Floor division keeps an int on Python 3; the check above makes it exact.
        self.skt_per_class = len(self.skt_labels) // self.num_classes

        self.hard_sel = 0  # not read anywhere in this class

        self.img_labels_dict, self.classes = vec2dic(self.img_labels)

        # A real list (not a Python 3 range object) so it can be shuffled in place.
        self.indexlist = list(range(len(self.skt_labels)))
        self.shuffle_keeping_min_classes_per_batch()
        self._cur = 0  # position in self.indexlist of the next anchor sketch

        # this class does some simple data-manipulations
        self.img_augment = SimpleAugment(mean=self.img_mean,
                                         shape=params['shape'],
                                         scale=params['scale'],
                                         rot=params['rot'])

        if 'verbose' not in params:
            print("BatchLoader initialized with {} sketches, {} images of {} classes".format(
                len(self.skt_labels), len(self.img_labels), self.num_classes))
        # thread pool for parallel augmentation
        self.pool = ThreadPool()

    def load_next_triplet(self, l):
        """Return the next (anchor, positive, negative) triplet.

        The anchor is the next sketch in ``self.indexlist``.  The positive
        image is drawn at random from the anchor's class, and the negative
        from a randomly chosen different class.

        Args:
            l: lock shared by all workers of one batch; guards ``self._cur``.

        Returns:
            dict with keys ``'anchor'``, ``'pos'``, ``'neg'`` (augmented
            arrays), ``'label_a'`` and ``'label_n'``.
        """
        # The with-block releases the lock even if reshuffling raises; a leaked
        # lock would deadlock every later worker.
        with l:
            if self._cur == len(self.indexlist):  # finished an epoch
                self._cur = 0
                self.shuffle_keeping_min_classes_per_batch()
            index = self.indexlist[self._cur]
            self._cur += 1

        skt = self.skt_lmdb.get_datum(index).squeeze()
        label = self.skt_labels[index]

        # randomly select pos and neg img (img_labels_dict is keyed by str(label))
        index_p = np.random.choice(self.img_labels_dict[str(label)])
        # NOTE(review): this loops forever if self.classes holds no class other
        # than `label`.
        label_n = label
        while label_n == label:
            label_n = np.random.choice(self.classes)
        index_n = np.random.choice(self.img_labels_dict[str(label_n)])

        img_p = self.img_lmdb.get_datum(index_p).squeeze()
        img_n = self.img_lmdb.get_datum(index_n).squeeze()

        res = dict(anchor=self.img_augment.augment(skt),
                   pos=self.img_augment.augment(img_p),
                   neg=self.img_augment.augment(img_n),
                   label_a=label,
                   label_n=label_n)
        return res

    def load_next_batch(self):
        """Load ``batch_size`` triplets in parallel and stack them into arrays."""
        res = {}
        lock = Lock()
        jobs = [self.pool.apply_async(self.load_next_triplet, (lock,))
                for _ in range(self.batch_size)]
        thread_res = [job.get() for job in jobs]
        res['data_a'] = np.asarray([tri['anchor']
                                    for tri in thread_res])[:, None, :, :]
        res['data_p'] = np.asarray([tri['pos']
                                    for tri in thread_res])[:, None, :, :]
        res['data_n'] = np.asarray([tri['neg']
                                    for tri in thread_res])[:, None, :, :]
        res['label_a'] = np.asarray([tri['label_a'] for tri in thread_res],
                                    dtype=np.float32)
        res['label_n'] = np.asarray([tri['label_n'] for tri in thread_res],
                                    dtype=np.float32)
        return res

    def shuffle_keeping_min_classes_per_batch(self):
        """Reorder ``self.indexlist`` so consecutive runs share a class.

        Sketch indices are grouped by class (random order within each class)
        and the order of the classes is shuffled.  The list is then rebuilt
        chunk by chunk: each chunk takes the next
        ``batch_size // classes_per_batch`` sketches (at least one) from every
        class in turn.  A final partial chunk takes whatever each class has
        left.  Assumes every class has exactly ``skt_per_class`` sketches;
        ``__init__`` only checks divisibility.
        """
        shuffle(self.indexlist)

        # Stable sort by label: grouped by class, random order within a class.
        by_class = sorted(self.indexlist, key=lambda index: self.skt_labels[index])

        # One row per class, then shuffle the order of the rows (classes).
        by_class = np.array(by_class, dtype=np.int64).reshape(
            self.num_classes, self.skt_per_class)
        by_class = np.random.permutation(by_class)

        # max(1, ...) avoids a zero range() step when batch_size < classes_per_batch.
        step = max(1, self.batch_size // self.classes_per_batch)
        # The column slice is clamped at the row end, so a partial last chunk
        # never reads into the next class's row or past the end of the array.
        chunks = [by_class[:, k:k + step].ravel()
                  for k in range(0, self.skt_per_class, step)]
        self.indexlist = np.concatenate(chunks).tolist() if chunks else []