def get_descs_and_labels(net: MLNet, sess: tf.Session, modal, paths_with_labels, process_fn, batch_size):
    """Compute descriptor vectors and labels for every sample of one modality.

    Runs the network's descriptor tensor over all samples in
    ``paths_with_labels``, batch by batch, and returns the stacked results.

    Args:
        net: the MLNet model; must be in neither training nor retrieving mode.
        sess: TensorFlow session to evaluate the descriptor tensor in.
        modal: 1 to use ``net.ph1``/``net.descriptors_1``,
               2 to use ``net.ph2``/``net.descriptors_2``.
        paths_with_labels: list of (path, label) samples to describe.
        process_fn: per-sample preprocessing function passed to DataLoader.
        batch_size: number of samples per evaluation batch.

    Returns:
        (descriptors, labels): two numpy arrays concatenated along axis 0,
        one row per input sample.

    Raises:
        Exception: if the net is in training/retrieving mode or ``modal``
        is not 1 or 2.
    """
    if net.is_training:
        raise Exception("should not run this in training mode")
    if net.is_retrieving:
        raise Exception("should not run this in retrieving mode")
    # Validate `modal` once, before any data is loaded (the original only
    # failed after the first batch had already been read), and select the
    # placeholder list / descriptor tensor a single time.
    if modal == 1:
        placeholders, desc_tensor = net.ph1, net.descriptors_1
    elif modal == 2:
        placeholders, desc_tensor = net.ph2, net.descriptors_2
    else:
        raise Exception("modal should be either 1 or 2")

    def _eval_batch(batch_data):
        # Pack the raw batch, feed each component into its placeholder,
        # and evaluate the descriptor tensor.
        feed_dict = dict(zip(placeholders, split_and_pack(batch_data)))
        return desc_tensor.eval(session=sess, feed_dict=feed_dict)

    descriptors = []
    labels = []
    loader = DataLoader(paths_with_labels, batch_size, shuffle=False, process_fn=process_fn)
    for batch in range(loader.n_batches):
        batch_data, batch_labels = loader.get_batch_by_index(batch)
        descriptors.append(_eval_batch(batch_data))
        labels.append(batch_labels)
    if loader.n_remain > 0:
        # The loader pads the final partial batch up to batch_size; keep only
        # the first n_remain (real) rows of the result.
        batch_data, batch_labels = loader.get_remaining()
        descriptors.append(_eval_batch(batch_data)[:loader.n_remain])
        labels.append(batch_labels[:loader.n_remain])
    descriptors = np.concatenate(descriptors, axis=0)
    labels = np.concatenate(labels, axis=0)
    return descriptors, labels
class DataPairLoader:
    """Abstract base class serving *paired* samples from two data lists.

    DataPairLoader accepts two lists of (path, label) entries and lets the
    training/testing program fetch data in batches according to some order.
    How the two lists are paired is decided by the abstract method
    ``generate_pair_indices``.

    for more information, see:
    https://gitlab.com/crossmodal2018/documentations/blob/master/%E6%95%B0%E6%8D%AE%E5%88%97%E8%A1%A8%E6%96%87%E4%BB%B6.pdf
    """

    __metaclass__ = ABCMeta  # this class is an abstract class

    def __init__(self, paths_with_labels_1, paths_with_labels_2, batch_size, n_classes,
                 shuffle, n_threads=8, process_fn_1=None, process_fn_2=None):
        """Build the pairing, optionally shuffle it, and set up both loaders.

        Args:
            paths_with_labels_1: (path, label) list for modality 1.
            paths_with_labels_2: (path, label) list for modality 2.
            batch_size: number of pairs per batch.
            n_classes: number of label classes (kept for subclasses).
            shuffle: whether to shuffle the generated pair order.
            n_threads: worker threads forwarded to each DataLoader.
            process_fn_1 / process_fn_2: per-sample preprocessing functions.
        """
        # Seed from the sub-second part of the clock so each run shuffles
        # differently.
        random.seed(int(1e6 * (time.time() % 1)))
        self.paths_with_labels_1 = paths_with_labels_1
        self.paths_with_labels_2 = paths_with_labels_2
        # parameters
        self.n_classes = n_classes
        self.batch_size = batch_size
        self.n_threads = n_threads
        self.n_samples_1 = len(self.paths_with_labels_1)
        self.n_samples_2 = len(self.paths_with_labels_2)
        self.dtype_labels = 'int32'
        # generate data pairs (subclass decides the pairing strategy)
        indices_1, indices_2 = self.generate_pair_indices()
        if shuffle:
            # Shuffle both index lists with the same permutation so pairs
            # stay aligned.
            indices = list(zip(indices_1, indices_2))
            random.shuffle(indices)
            indices_1, indices_2 = zip(*indices)
        loader_1_list = [self.paths_with_labels_1[i] for i in indices_1]
        loader_2_list = [self.paths_with_labels_2[i] for i in indices_2]
        # initialize loaders
        self.loader_1 = DataLoader(loader_1_list, batch_size=self.batch_size,
                                   n_threads=n_threads, process_fn=process_fn_1)
        self.loader_2 = DataLoader(loader_2_list, batch_size=self.batch_size,
                                   n_threads=n_threads, process_fn=process_fn_2)
        # state
        self.n_pairs = len(indices_1)
        # Integer floor-division instead of math.floor(a / b): same result
        # for non-negative ints, no float round-trip.
        self.n_batches = self.n_pairs // self.batch_size
        self.n_remain = self.n_pairs % batch_size
        self.i = 0
        # async_load
        self.async_load_pool = [None, None, None]
        self.async_load_thread = None

    @abstractmethod
    def generate_pair_indices(self):
        """Return two equally long index lists defining the sample pairing."""
        return [], []

    def reset(self):
        """Rewind this loader and both underlying loaders to batch 0."""
        self.i = 0
        self.loader_1.reset()
        self.loader_2.reset()

    def set_batch_index(self, i):
        """Jump to batch ``i`` on this loader and both underlying loaders."""
        self.i = i
        self.loader_1.set_batch_index(i)
        self.loader_2.set_batch_index(i)

    def _pack_pair(self, data_1, labels_1, data_2, labels_2):
        # Shared tail of get_batch_by_index/get_remaining: pair label is 1
        # when the two samples share a class, 0 otherwise; both sides are
        # packed exactly once here.
        labels = (labels_1 == labels_2).astype(self.dtype_labels)
        return split_and_pack(data_1), split_and_pack(data_2), labels

    def get_batch_by_index(self, i):
        """Return (data_1, data_2, pair_labels) for full batch ``i``."""
        data_1, labels_1 = self.loader_1.get_batch_by_index(i)
        data_2, labels_2 = self.loader_2.get_batch_by_index(i)
        return self._pack_pair(data_1, labels_1, data_2, labels_2)

    def get_remaining(self):
        """Return the final partial batch as (data_1, data_2, pair_labels)."""
        data_1, labels_1 = self.loader_1.get_remaining()
        data_2, labels_2 = self.loader_2.get_remaining()
        return self._pack_pair(data_1, labels_1, data_2, labels_2)

    def next_batch(self):
        """Return the next batch and advance; ([], [], None) when exhausted.

        BUG FIX: the original applied split_and_pack to both sides a second
        time here, although get_batch_by_index already packs them.
        """
        if self.i >= self.n_batches:
            return [], [], None
        data_1, data_2, labels = self.get_batch_by_index(self.i)
        self.i += 1
        return data_1, data_2, labels

    def get_pair_by_index(self, i):
        """Return (datup_1, datup_2, label) for the single pair at index ``i``."""
        datup_1, label_1 = self.loader_1.get_datup_at_index(i)
        datup_2, label_2 = self.loader_2.get_datup_at_index(i)
        label = label_1 == label_2
        return datup_1, datup_2, int(label)

    def async_load_batch(self, i):
        """Start loading batch ``i`` on a background thread.

        Waits for any previous async load to finish first, so at most one
        loader thread runs at a time.
        """
        if self.async_load_thread is not None:
            self.async_load_thread.join()
        self.async_load_thread = AsyncLoadThread(self, i)
        self.async_load_thread.start()

    def get_async_loaded(self):
        """Block until the async load finishes and return its batch.

        Raises:
            Exception: if async_load_batch was never called.
        """
        if self.async_load_thread is None:
            raise Exception('Did not load anything')
        self.async_load_thread.join()
        # AsyncLoadThread deposits (data_1, data_2, labels) into the pool —
        # presumably; verify against AsyncLoadThread's implementation.
        data_1, data_2, labels = self.async_load_pool
        return data_1, data_2, labels