コード例 #1
0
ファイル: sampler_asy.py プロジェクト: y2242794082/SBIR
class SamplingDataFetcher(Process):
    """Background process that keeps a queue filled with sketch batches.

    Each batch holds ``batch_size`` sketches read from a
    ``MemoryBlockManager`` over ``sketch_dir``, reshaped to (256, 256, 1),
    passed through ``Transformer(225, 1, mean).transform_all`` and cast to
    ``np.uint8`` before being ``put`` on the queue.
    """

    def __init__(self, queue, sketch_dir, mean, hard_ratio, batch_size):
        """Set up the sketch-sampling fetcher.

        Args:
            queue: object whose ``put`` receives every batch (presumably a
                ``multiprocessing.Queue`` shared with the consumer -- confirm
                at the call site).
            sketch_dir: source handed to ``MemoryBlockManager``.
            mean: mean value forwarded to ``Transformer``.
            hard_ratio: stored on the instance; not used by this class.
            batch_size: number of sketches per batch.
        """
        super(SamplingDataFetcher, self).__init__()
        self._queue = queue
        self.sketch_transformer = Transformer(225, 1, mean)
        self.sketch_dir = sketch_dir
        self.sketch_bm = MemoryBlockManager(sketch_dir)
        self.hard_ratio = hard_ratio
        self.mini_batchsize = batch_size

    def get_next_batch(self):
        """Fetch, transform and enqueue one batch of sketches."""
        sketch_inds = self.sketch_bm.pop_batch_inds_circular(self.mini_batchsize)
        # Iterate the indices directly. The former ``for (sketch_id) in
        # zip(sketch_inds)`` handed ``get_sample`` a 1-tuple ``(idx,)``
        # rather than ``idx`` -- the parentheses do not unpack anything.
        sketch_batch = [
            self.sketch_bm.get_sample(sketch_id).reshape((256, 256, 1))
            for sketch_id in sketch_inds
        ]
        # NOTE(review): if transform_all subtracts ``mean`` the values can be
        # negative, and astype(np.uint8) would wrap them -- confirm intent.
        sketch_batch = self.sketch_transformer.transform_all(sketch_batch).astype(np.uint8)
        self._queue.put(sketch_batch)

    def run(self):
        """Process entry point: produce batches forever."""
        print('SamplingDataFetcher started')
        while True:
            self.get_next_batch()
コード例 #2
0
ファイル: sampler_asy.py プロジェクト: y2242794082/SBIR
 def __init__(self, queue, sketch_dir,  mean, hard_ratio, batch_size):
     """Initialise the sketch-sampling fetcher process.

     Keeps ``queue`` for outgoing batches, builds a
     ``Transformer(225, 1, mean)`` and a ``MemoryBlockManager`` over
     ``sketch_dir``, and records ``hard_ratio`` and ``batch_size``.
     """
     super(SamplingDataFetcher, self).__init__()
     self._queue = queue
     crop_size, n_channels = 225, 1
     self.sketch_transformer = Transformer(crop_size, n_channels, mean)
     self.sketch_dir = sketch_dir
     self.sketch_bm = MemoryBlockManager(self.sketch_dir)
     self.hard_ratio, self.mini_batchsize = hard_ratio, batch_size
コード例 #3
0
ファイル: triplet_sampler.py プロジェクト: CityU-HAN/DeepSBIR
 def __init__(self, queue, layer_params):
     """Initialise the triplet-sampling fetcher from a parameter dict.

     ``layer_params`` is read for 'mean', 'phase', 'sketch_dir',
     'image_dir', 'hard_ratio', 'batch_size' and 'triplet_path'.  The
     transformer's last argument is True only when the phase is "TRAIN".
     """
     super(TripletSamplingDataFetcher, self).__init__()
     self._queue = queue
     params = layer_params
     mean = params['mean']
     self._phase = params['phase']
     crop_size, n_channels = 225, 1
     is_train = self._phase == "TRAIN"
     self.sketch_transformer = Transformer(crop_size, n_channels, mean, is_train)
     self.anc_bm = MemoryBlockManager(params['sketch_dir'])
     self.pos_neg_bm = MemoryBlockManager(params['image_dir'])
     self.hard_ratio = params['hard_ratio']
     self.mini_batchsize = params['batch_size']
     self.load_triplets(params['triplet_path'])
コード例 #4
0
 def __init__(self, queue, sketch_dir, image_dir, triplet_path, mean, hard_ratio, batch_size, phase):
     """Initialise the triplet-sampling fetcher from explicit arguments.

     Builds a ``Transformer(225, 1, mean, phase == "TRAIN")``, one
     ``MemoryBlockManager`` for the anchors (``sketch_dir``) and one for
     positives/negatives (``image_dir``), then loads the triplets from
     ``triplet_path``.
     """
     super(TripletSamplingDataFetcher, self).__init__()
     self._queue, self._phase = queue, phase
     is_train = phase == "TRAIN"
     self.sketch_transformer = Transformer(225, 1, mean, is_train)
     self.sketch_dir = sketch_dir
     self.anc_bm = MemoryBlockManager(sketch_dir)
     self.pos_neg_bm = MemoryBlockManager(image_dir)
     self.hard_ratio, self.mini_batchsize = hard_ratio, batch_size
     self.load_triplets(triplet_path)
コード例 #5
0
 def __init__(self, queue, layer_params):
     """Build the triplet fetcher from the ``layer_params`` dict.

     Keys read (in this order): 'mean', 'phase', 'sketch_dir',
     'image_dir', 'hard_ratio', 'batch_size', 'triplet_path'.
     """
     super(TripletSamplingDataFetcher, self).__init__()
     self._queue = queue
     mean, self._phase = layer_params['mean'], layer_params['phase']
     self.sketch_transformer = Transformer(
         225, 1, mean, self._phase == "TRAIN")
     sketch_source = layer_params['sketch_dir']
     self.anc_bm = MemoryBlockManager(sketch_source)
     image_source = layer_params['image_dir']
     self.pos_neg_bm = MemoryBlockManager(image_source)
     self.hard_ratio = layer_params['hard_ratio']
     self.mini_batchsize = layer_params['batch_size']
     self.load_triplets(layer_params['triplet_path'])
コード例 #6
0
class TripletSamplingDataFetcher_lot(Process):
    """Background process producing triplet batches from two image sources.

    Every item put on the queue is the 5-tuple
    ``(anc, pos1, neg1, pos2, neg2)``. Anchors come from ``sketch_dir``;
    ``pos1``/``neg1`` come from ``dir1`` and ``pos2``/``neg2`` from ``dir2``,
    both indexed by the same positive/negative ids.
    """

    # Shape each raw sample is reshaped to before transformation.
    _SAMPLE_SHAPE = (256, 256, 1)

    def __init__(self, queue, sketch_dir, dir1, dir2, triplet_path, mean, hard_ratio, batch_size, phase):
        """Set up the fetcher.

        Args:
            queue: object whose ``put`` receives every 5-tuple batch.
            sketch_dir: anchor source; also checked for 'handbags' when
                choosing the sampling strategy.
            dir1, dir2: the two positive/negative sources.
            triplet_path: passed to the module-level ``load_triplets``.
            mean: forwarded to ``Transformer``.
            hard_ratio: forwarded to ``sample_triplets_pos_neg``.
            batch_size: anchors per batch.
            phase: the transformer's last argument is ``phase == "TRAIN"``.
        """
        super(TripletSamplingDataFetcher_lot, self).__init__()
        self._queue = queue
        self._phase = phase
        self.sketch_transformer = Transformer(225, 1, mean, self._phase == "TRAIN")
        self.sketch_dir = sketch_dir
        self.anc_bm = MemoryBlockManager(sketch_dir)
        self.pos_neg_bm1 = MemoryBlockManager(dir1)
        self.pos_neg_bm2 = MemoryBlockManager(dir2)
        self.hard_ratio = hard_ratio
        self.mini_batchsize = batch_size
        self.load_triplets(triplet_path)

    def load_triplets(self, triplet_path):
        """Store the triplets and negative list returned by ``load_triplets``."""
        self.triplets, self.neg_list = load_triplets(triplet_path, self._phase)

    def _sample_pos_neg(self, anc_inds):
        """Return ``(pos_inds, neg_inds)`` for the given anchor ids."""
        if 'handbags' in self.sketch_dir:
            # positives are always the true match for this dataset
            return sample_triplets_trueMatch(anc_inds, self._phase)
        return sample_triplets_pos_neg(anc_inds, self.triplets, self.neg_list, self.hard_ratio)

    def _transform_batch(self, batch):
        """Run ``sketch_transformer.transform_all`` on a batch and cast to uint8."""
        # NOTE(review): if transform_all subtracts ``mean`` the values can be
        # negative, and astype(np.uint8) would wrap them -- confirm intent.
        return self.sketch_transformer.transform_all(batch).astype(np.uint8)

    def get_next_batch(self):
        """Fetch, transform and enqueue one ``(anc, pos1, neg1, pos2, neg2)`` batch."""
        anc_inds = self.anc_bm.pop_batch_inds_circular(self.mini_batchsize)
        pos_inds, neg_inds = self._sample_pos_neg(anc_inds)

        shape = self._SAMPLE_SHAPE
        anc_batch, pos_batch1, neg_batch1, pos_batch2, neg_batch2 = [], [], [], [], []
        for anc_id, pos_id, neg_id in zip(anc_inds, pos_inds, neg_inds):
            anc_batch.append(self.anc_bm.get_sample(anc_id).reshape(shape))
            pos_batch1.append(self.pos_neg_bm1.get_sample(pos_id).reshape(shape))
            neg_batch1.append(self.pos_neg_bm1.get_sample(neg_id).reshape(shape))
            pos_batch2.append(self.pos_neg_bm2.get_sample(pos_id).reshape(shape))
            neg_batch2.append(self.pos_neg_bm2.get_sample(neg_id).reshape(shape))

        # Transform in the fixed order anc, pos1, neg1, pos2, neg2 so any
        # random state inside the transformer is consumed as before.
        batches = (anc_batch, pos_batch1, neg_batch1, pos_batch2, neg_batch2)
        self._queue.put(tuple(self._transform_batch(b) for b in batches))

    def run(self):
        """Process entry point: produce batches forever."""
        print('TripletSamplingDataFetcher started')
        while True:
            self.get_next_batch()
コード例 #7
0
class TripletSamplingDataFetcher(Process):
    """Background process that keeps a queue filled with triplet batches.

    Each item put on the queue is ``(anc, pos, neg)``. Anchors come from
    ``layer_params['sketch_dir']`` and positives/negatives from
    ``layer_params['image_dir']``. Every sample is reshaped to (1, 256, 256)
    and passed through ``sketch_transformer.transform_all``.
    """

    def __init__(self, queue, layer_params):
        """Set up the fetcher.

        Args:
            queue: object whose ``put`` receives every batch.
            layer_params: dict read for 'mean', 'phase', 'sketch_dir',
                'image_dir', 'hard_ratio', 'batch_size' and 'triplet_path'.
        """
        super(TripletSamplingDataFetcher, self).__init__()
        self._queue = queue
        mean = layer_params['mean']
        self._phase = layer_params['phase']
        self.sketch_transformer = Transformer(225, 1, mean,
                                              self._phase == "TRAIN")
        self.anc_bm = MemoryBlockManager(layer_params['sketch_dir'])
        self.pos_neg_bm = MemoryBlockManager(layer_params['image_dir'])
        self.hard_ratio = layer_params['hard_ratio']
        self.mini_batchsize = layer_params['batch_size']
        self.load_triplets(layer_params['triplet_path'])

    def load_triplets(self, triplet_path):
        """Store the triplets and negative list returned by ``load_triplets``."""
        self.triplets, self.neg_list = load_triplets(triplet_path, self._phase)

    def get_next_batch(self):
        """Fetch, transform and enqueue one ``(anc, pos, neg)`` batch."""
        anc_inds = self.anc_bm.pop_batch_inds_circular(self.mini_batchsize)
        pos_inds, neg_inds = sample_triplets(anc_inds, self.triplets,
                                             self.neg_list, self.hard_ratio)
        shape = (1, 256, 256)
        anc_batch, pos_batch, neg_batch = [], [], []
        for anc_id, pos_id, neg_id in zip(anc_inds, pos_inds, neg_inds):
            anc_batch.append(self.anc_bm.get_sample(anc_id).reshape(shape))
            pos_batch.append(self.pos_neg_bm.get_sample(pos_id).reshape(shape))
            neg_batch.append(self.pos_neg_bm.get_sample(neg_id).reshape(shape))
        anc_batch = self.sketch_transformer.transform_all(anc_batch)
        pos_batch = self.sketch_transformer.transform_all(pos_batch)
        neg_batch = self.sketch_transformer.transform_all(neg_batch)
        self._queue.put((anc_batch, pos_batch, neg_batch))

    def run(self):
        """Process entry point: produce batches forever."""
        # Function-call form: valid on both Python 2 and 3 (the former
        # ``print '...'`` statement is a SyntaxError on Python 3).
        print('TripletSamplingDataFetcher started')
        while True:
            self.get_next_batch()
コード例 #8
0
 def __init__(self, queue, sketch_data, image_data, sketch_label, image_label, triplets, queue_paras, phase):
     """Initialise the fetcher from in-memory data and a parameter dict.

     ``queue_paras`` is read for 'im_size', 'cp_size', 'chns', 'mean',
     'batch_size' and 'num_epochs'.  Labels are used only when both
     ``sketch_label`` and ``image_label`` are provided.
     """
     super(TripletSamplingDataFetcher, self).__init__()
     paras = queue_paras
     self.im_size = paras['im_size']
     self.cp_size = paras['cp_size']
     self.chns = paras['chns']
     self.mean = paras['mean']
     self.batchsize = paras['batch_size']
     self.num_epoch = paras['num_epochs']
     self._queue, self._phase = queue, phase
     is_train = phase == "TRAIN"
     self.sketch_transformer = Transformer(self.cp_size, self.chns, self.mean, is_train)
     self.has_label = sketch_label is not None and image_label is not None
     self.anc_bm = MemoryBlockManager(sketch_data, sketch_label, self.has_label)
     self.pos_neg_bm = MemoryBlockManager(image_data, image_label, self.has_label)
     self.triplets = triplets
コード例 #9
0
ファイル: sbir_retrieval.py プロジェクト: CityU-HAN/DeepSBIR
def extract_features(net_, config):
    """Compute features for every sample in ``config.DB_PATH`` with ``net_``.

    Batches of ``config.batchsize`` samples are popped from a
    ``MemoryBlockManager`` until ``eof()``. Each batch is expanded with
    ``do_multiview_crop(batch, config.crop_dim)``, has ``config.mean_val``
    subtracted, and is forwarded through ``net_``. The ``config.feat_layer``
    output is flattened by ``reshape_feature`` and copied into the result.

    Args:
        net_: network exposing ``blobs['data'].reshape`` and ``forward``
            (looks like a Caffe ``Net`` -- confirm).
        config: object with DB_PATH, verbose, dataset, batchsize, crop_dim,
            mean_val and feat_layer attributes.

    Returns:
        ``np.single`` array of shape ``(10 * db.num_samples, feat_dim)``.
        Rows are filled in order and any rows not written stay zero. The
        factor 10 presumably matches the views produced per sample by
        ``do_multiview_crop`` (TODO confirm).
    """
    db = MemoryBlockManager(config.DB_PATH, False)
    if config.verbose:
        print('extracting features of %s' % config.dataset)
    t = time()
    num_samples = 10 * db.num_samples
    feats = np.zeros((num_samples, get_feature_dims(net_, config.feat_layer)), np.single)
    idx = 0
    while not db.eof():
        batch_data, _ = db.pop_batch(config.batchsize)
        transformed = do_multiview_crop(batch_data, config.crop_dim)
        transformed = transformed - config.mean_val
        num, chns, rows, cols = transformed.shape
        # Resize the input blob to this batch's shape before the forward pass.
        net_.blobs['data'].reshape(num, chns, rows, cols)
        output = net_.forward(data=transformed.astype(np.float32, copy=False))
        feats[idx:idx + num, ::] = reshape_feature(output[config.feat_layer])
        idx += num
    if config.verbose:
        print('feature computation completed (%0.2f sec.)' % (time() - t))
    return feats
コード例 #10
0
class TripletSamplingLayer(object):
    """Synchronous triplet sampler.

    Unlike the process-based fetchers, ``get_next_batch`` returns the
    ``(anc, pos, neg)`` batch directly. A ``Queue(10)`` is still created and
    stored as ``_queue``, but this class never puts anything on it.
    """

    def __init__(self, sketch_dir, image_dir, triplet_path, mean, hard_ratio,
                 batch_size, phase):
        """Build the transformer and block managers, then load the triplets.

        The transformer is ``Transformer(225, 1, mean, phase == "TRAIN")``.
        Anchors come from ``sketch_dir`` and positives/negatives from
        ``image_dir``.
        """
        self._queue = Queue(10)
        self._phase = phase
        is_train = phase == "TRAIN"
        self.sketch_transformer = Transformer(225, 1, mean, is_train)
        self.anc_bm = MemoryBlockManager(sketch_dir)
        self.pos_neg_bm = MemoryBlockManager(image_dir)
        self.hard_ratio, self.mini_batchsize = hard_ratio, batch_size
        self.load_triplets(triplet_path)

    def load_triplets(self, triplet_path):
        """Store the triplets and negative list from module-level ``load_triplets``."""
        loaded = load_triplets(triplet_path, self._phase)
        self.triplets, self.neg_list = loaded

    def get_next_batch(self):
        """Sample one batch and return ``(anc, pos, neg)`` after transformation.

        Each raw sample is reshaped to (256, 256, 1) before transform.
        """
        shape = (256, 256, 1)
        anc_ids = self.anc_bm.pop_batch_inds_circular(self.mini_batchsize)
        pos_ids, neg_ids = sample_triplets(anc_ids, self.triplets,
                                           self.neg_list, self.hard_ratio)

        anchors, positives, negatives = [], [], []
        for a_id, p_id, n_id in zip(anc_ids, pos_ids, neg_ids):
            anchors.append(self.anc_bm.get_sample(a_id).reshape(shape))
            positives.append(self.pos_neg_bm.get_sample(p_id).reshape(shape))
            negatives.append(self.pos_neg_bm.get_sample(n_id).reshape(shape))

        transform = self.sketch_transformer.transform_all
        return transform(anchors), transform(positives), transform(negatives)
コード例 #11
0
def extract_features(net_, config):
    """Compute features for every sample in ``config.DB_PATH`` with ``net_``.

    Batches of ``config.batchsize`` samples are popped from a
    ``MemoryBlockManager`` until ``eof()``. Each batch is expanded with
    ``do_multiview_crop(batch, config.crop_dim)``, has ``config.mean_val``
    subtracted, and is forwarded through ``net_``. The ``config.feat_layer``
    output is flattened by ``reshape_feature`` and copied into the result.

    Args:
        net_: network exposing ``blobs['data'].reshape`` and ``forward``
            (looks like a Caffe ``Net`` -- confirm).
        config: object with DB_PATH, verbose, dataset, batchsize, crop_dim,
            mean_val and feat_layer attributes.

    Returns:
        ``np.single`` array of shape ``(10 * db.num_samples, feat_dim)``.
        Rows are filled in order and any rows not written stay zero. The
        factor 10 presumably matches the views produced per sample by
        ``do_multiview_crop`` (TODO confirm).
    """
    db = MemoryBlockManager(config.DB_PATH, False)
    if config.verbose:
        print('extracting features of %s' % config.dataset)
    t = time()
    num_samples = 10 * db.num_samples
    feats = np.zeros((num_samples, get_feature_dims(net_, config.feat_layer)),
                     np.single)
    idx = 0
    while not db.eof():
        batch_data, _ = db.pop_batch(config.batchsize)
        transformed = do_multiview_crop(batch_data, config.crop_dim)
        transformed = transformed - config.mean_val
        num, chns, rows, cols = transformed.shape
        # Resize the input blob to this batch's shape before the forward pass.
        net_.blobs['data'].reshape(num, chns, rows, cols)
        output = net_.forward(data=transformed.astype(np.float32, copy=False))
        feats[idx:idx + num, ::] = reshape_feature(output[config.feat_layer])
        idx += num
    if config.verbose:
        print('feature computation completed (%0.2f sec.)' % (time() - t))
    return feats
コード例 #12
0
ファイル: triplet_sampler.py プロジェクト: CityU-HAN/DeepSBIR
class TripletSamplingDataFetcher(Process):
    """Background process that keeps a queue filled with triplet batches.

    Each item put on the queue is ``(anc, pos, neg)``. Anchors come from
    ``layer_params['sketch_dir']`` and positives/negatives from
    ``layer_params['image_dir']``. Every sample is reshaped to (1, 256, 256)
    and passed through ``sketch_transformer.transform_all``.
    """

    def __init__(self, queue, layer_params):
        """Set up the fetcher.

        Args:
            queue: object whose ``put`` receives every batch.
            layer_params: dict read for 'mean', 'phase', 'sketch_dir',
                'image_dir', 'hard_ratio', 'batch_size' and 'triplet_path'.
        """
        super(TripletSamplingDataFetcher, self).__init__()
        self._queue = queue
        mean = layer_params['mean']
        self._phase = layer_params['phase']
        self.sketch_transformer = Transformer(225, 1, mean, self._phase == "TRAIN")
        self.anc_bm = MemoryBlockManager(layer_params['sketch_dir'])
        self.pos_neg_bm = MemoryBlockManager(layer_params['image_dir'])
        self.hard_ratio = layer_params['hard_ratio']
        self.mini_batchsize = layer_params['batch_size']
        self.load_triplets(layer_params['triplet_path'])

    def load_triplets(self, triplet_path):
        """Store the triplets and negative list returned by ``load_triplets``."""
        self.triplets, self.neg_list = load_triplets(triplet_path, self._phase)

    def get_next_batch(self):
        """Fetch, transform and enqueue one ``(anc, pos, neg)`` batch."""
        anc_inds = self.anc_bm.pop_batch_inds_circular(self.mini_batchsize)
        pos_inds, neg_inds = sample_triplets(anc_inds, self.triplets, self.neg_list, self.hard_ratio)
        shape = (1, 256, 256)
        anc_batch, pos_batch, neg_batch = [], [], []
        for anc_id, pos_id, neg_id in zip(anc_inds, pos_inds, neg_inds):
            anc_batch.append(self.anc_bm.get_sample(anc_id).reshape(shape))
            pos_batch.append(self.pos_neg_bm.get_sample(pos_id).reshape(shape))
            neg_batch.append(self.pos_neg_bm.get_sample(neg_id).reshape(shape))
        anc_batch = self.sketch_transformer.transform_all(anc_batch)
        pos_batch = self.sketch_transformer.transform_all(pos_batch)
        neg_batch = self.sketch_transformer.transform_all(neg_batch)
        self._queue.put((anc_batch, pos_batch, neg_batch))

    def run(self):
        """Process entry point: produce batches forever."""
        # Function-call form: valid on both Python 2 and 3 (the former
        # ``print '...'`` statement is a SyntaxError on Python 3).
        print('TripletSamplingDataFetcher started')
        while True:
            self.get_next_batch()
コード例 #13
0
class TripletSamplingDataFetcher(Process):
    """Background process that keeps a queue filled with triplet batches.

    Anchors are read from ``sketch_data`` and positives/negatives from
    ``image_data``, each wrapped in a ``MemoryBlockManager``. When both label
    arrays are supplied, every queued item also carries the labels (see
    ``get_next_batch_data_label``). Otherwise only the data is queued (see
    ``get_next_batch_data``).
    """

    def __init__(self, queue, sketch_data, image_data, sketch_label, image_label, triplets, queue_paras, phase):
        """Set up the fetcher.

        Args:
            queue: object whose ``put`` receives every batch.
            sketch_data, image_data: passed to ``MemoryBlockManager``.
            sketch_label, image_label: labels for the two sources; if either
                is None the fetcher runs without labels.
            triplets: forwarded to ``sample_triplets`` for every batch.
            queue_paras: dict read for 'im_size', 'cp_size', 'chns', 'mean',
                'batch_size' and 'num_epochs'.
            phase: the transformer's last argument is ``phase == "TRAIN"``.
        """
        super(TripletSamplingDataFetcher, self).__init__()
        self.im_size = queue_paras['im_size']
        self.cp_size = queue_paras['cp_size']
        self.chns = queue_paras['chns']
        self.mean = queue_paras['mean']
        self.batchsize = queue_paras['batch_size']
        self.num_epoch = queue_paras['num_epochs']
        self._queue = queue
        self._phase = phase
        self.sketch_transformer = Transformer(self.cp_size, self.chns, self.mean, self._phase == "TRAIN")
        self.has_label = sketch_label is not None and image_label is not None
        self.anc_bm = MemoryBlockManager(sketch_data, sketch_label, self.has_label)
        self.pos_neg_bm = MemoryBlockManager(image_data, image_label, self.has_label)
        self.triplets = triplets

    def _sample_batch_inds(self):
        """Return ``(anc_inds, pos_inds, neg_inds)`` for one batch."""
        # Builtin next() works on Python 2.6+ and 3; the former ``.next()``
        # method call raises AttributeError on Python 3 generators.
        # NOTE(review): a fresh generator is created on every call and only
        # its first item is used. If ``sample_triplets`` is an epoch-bounded
        # or unshuffled generator, ``num_epoch`` has no effect here and every
        # batch may be the same -- confirm against sample_triplets.
        batch = next(sample_triplets(self.triplets, self.batchsize, self.num_epoch))
        anc_inds, pos_inds, neg_inds = zip(*batch)
        return anc_inds, pos_inds, neg_inds

    def get_next_batch_data(self):
        """Fetch, transform and enqueue one unlabelled ``(anc, pos, neg)`` batch."""
        anc_inds, pos_inds, neg_inds = self._sample_batch_inds()
        shape = (self.im_size, self.im_size, self.chns)
        anc_batch, pos_batch, neg_batch = [], [], []
        for anc_id, pos_id, neg_id in zip(anc_inds, pos_inds, neg_inds):
            anc_batch.append(self.anc_bm.get_sample(anc_id).reshape(shape))
            pos_batch.append(self.pos_neg_bm.get_sample(pos_id).reshape(shape))
            neg_batch.append(self.pos_neg_bm.get_sample(neg_id).reshape(shape))
        anc_batch = self.sketch_transformer.transform_all(anc_batch)
        pos_batch = self.sketch_transformer.transform_all(pos_batch)
        neg_batch = self.sketch_transformer.transform_all(neg_batch)
        self._queue.put((anc_batch, pos_batch, neg_batch))

    def get_next_batch_data_label(self):
        """Fetch, transform and enqueue one labelled batch.

        The queued 6-tuple is
        ``(anc_data, pos_data, neg_data, anc_label, pos_label, neg_label)``.
        Samples are passed to the transformer as returned by ``get_sample``,
        without reshaping.
        """
        anc_inds, pos_inds, neg_inds = self._sample_batch_inds()
        anc_batch, pos_batch, neg_batch = [], [], []
        for anc_id, pos_id, neg_id in zip(anc_inds, pos_inds, neg_inds):
            anc_batch.append(self.anc_bm.get_sample(anc_id))
            pos_batch.append(self.pos_neg_bm.get_sample(pos_id))
            neg_batch.append(self.pos_neg_bm.get_sample(neg_id))
        anc_batch_data, anc_batch_label = self.sketch_transformer.transform_all_with_label(anc_batch)
        pos_batch_data, pos_batch_label = self.sketch_transformer.transform_all_with_label(pos_batch)
        neg_batch_data, neg_batch_label = self.sketch_transformer.transform_all_with_label(neg_batch)
        self._queue.put((anc_batch_data, pos_batch_data, neg_batch_data,
                         anc_batch_label, pos_batch_label, neg_batch_label))

    def run(self):
        """Process entry point: produce labelled or unlabelled batches forever."""
        print('TripletSamplingDataFetcher started')
        produce = self.get_next_batch_data_label if self.has_label else self.get_next_batch_data
        while True:
            produce()