def transform(self, images):
    """Randomly crop a batch of images and build matching low-resolution copies.

    Each image is passed through ``self.random_crop``. The crop is then
    converted with ``self.toTensor`` both at full size and after
    ``self.resize_lr``, and each result is rescaled with ``x * 2 - 1``
    (presumably mapping [0, 1] to [-1, 1], assuming ``toTensor`` yields values
    in [0, 1] — confirm).

    Args:
        images: sequence of images accepted by ``self.random_crop``.

    Returns:
        Tuple ``(hr, lr)`` of ``wrap_Variable``-wrapped tensors, with shapes
        ``(len(images), 3, self._imsize, self._imsize)`` and
        ``(len(images), 3, self._imsize // self.hr_lr_ratio,
        self._imsize // self.hr_lr_ratio)``.

    Raises:
        NotImplementedError: if ``self._aug_flag`` is falsy; only the
            augmentation path is supported.
    """
    if not self._aug_flag:
        # Raise explicitly instead of ``assert False``: asserts are stripped
        # under ``python -O``, which would let this path return something other
        # than the (hr, lr) pair that callers unpack.
        raise NotImplementedError(
            "transform() only supports the augmentation path (_aug_flag=True)")

    count = len(images)
    lr_size = self._imsize // self.hr_lr_ratio
    transformed_images = torch.zeros(count, 3, self._imsize, self._imsize)
    lr_images = torch.zeros(count, 3, lr_size, lr_size)
    for i, image in enumerate(images):
        cropped = self.random_crop(image)
        # The low-res copy is derived from the same crop so both views align.
        lr_images[i] = self.toTensor(self.resize_lr(cropped)) * 2. - 1
        transformed_images[i] = self.toTensor(cropped) * 2. - 1
    return wrap_Variable(transformed_images), wrap_Variable(lr_images)
def next_batch_test(self, batch_size, start, max_captions):
    """Return the next `batch_size` examples from this data set.

    The batch is the contiguous, unshuffled window ``[start, start + batch_size)``.
    If that window runs past the end of the data set, it is shifted back so it
    ends at the last example. The window start is clamped at 0, so a
    ``batch_size`` larger than the data set yields the whole data set. Without
    the clamp, a negative start would wrap around via Python's negative slicing
    and return only the tail.

    Args:
        batch_size: number of examples requested.
        start: index of the first example of the window.
        max_captions: upper bound on the number of per-caption embedding
            batches returned.

    Returns:
        List ``[images, lr_images, embedding_batches, save_ids, captions]``:
        ``images``/``lr_images`` come from ``self.transform``;
        ``embedding_batches`` holds one ``wrap_Variable`` per index
        ``i < min(max_captions, embeddings.size()[1])``; ``save_ids`` is the
        matching slice of ``self._saveIDs``; ``captions`` holds the
        ``self.readCaptions`` result for each example.
    """
    if start + batch_size > self._num_examples:
        end = self._num_examples
        start = max(0, end - batch_size)
    else:
        end = start + batch_size

    sampled_images, lr_images = self.transform(self._images[start:end])

    sampled_embeddings = self._embeddings[start:end]
    # ``.size()`` implies a torch tensor; dim 1 presumably indexes the
    # per-image captions — confirm against the loader.
    embedding_num = sampled_embeddings.size()[1]

    sampled_captions = [
        self.readCaptions(filename, class_id)
        for filename, class_id in zip(self._filenames[start:end],
                                      self._class_id[start:end])
    ]

    sampled_embeddings_batchs = []
    for i in range(min(max_captions, embedding_num)):
        batch = sampled_embeddings[:, i, :]
        # NOTE(review): np.squeeze also drops the batch dimension when only one
        # example is sampled — confirm callers expect that.
        sampled_embeddings_batchs.append(
            wrap_Variable(torch.FloatTensor(np.squeeze(batch))))

    return [
        sampled_images, lr_images, sampled_embeddings_batchs,
        self._saveIDs[start:end], sampled_captions
    ]
def next_batch(self, batch_size, window):
    """Return the next `batch_size` examples from this data set.

    Real examples are taken in order from ``self._perm``. When the current
    permutation cannot supply a full batch, a new epoch begins: the epoch
    counter is incremented, ``self._perm`` is reshuffled, and the batch starts
    at position 0. A second set of "wrong" images is drawn uniformly at
    random. Draws whose class id equals the paired real example's class are
    shifted by one random offset in [100, 200) modulo the data set size. This
    is best effort; a shifted id can still collide.

    Args:
        batch_size: number of examples per batch (must not exceed the data
            set size when an epoch rolls over).
        window: forwarded unchanged to ``self.sample_embeddings``.

    Returns:
        A 7-element list: ``[images, lr_images, wrong_images,
        lr_wrong_images, embeddings, captions, labels]``. ``embeddings`` and
        ``captions`` are ``None`` when ``self._embeddings`` is ``None``;
        ``labels`` is ``None`` when ``self._labels`` is ``None``.
    """
    begin = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
        # Permutation exhausted: roll over into a freshly shuffled epoch.
        self._epochs_completed += 1
        self._perm = np.arange(self._num_examples)
        np.random.shuffle(self._perm)
        begin = 0
        self._index_in_epoch = batch_size
        assert batch_size <= self._num_examples
    real_ids = self._perm[begin:self._index_in_epoch]

    total = self._num_examples
    wrong_ids = np.random.randint(total, size=batch_size)
    same_class = self._class_id[real_ids] == self._class_id[wrong_ids]
    # One offset is drawn per call, even when nothing collides.
    shift = random.randrange(100, 200)
    wrong_ids[same_class] = (wrong_ids[same_class] + shift) % total

    real_images = [self._images[idx] for idx in real_ids]
    wrong_images = [self._images[idx] for idx in wrong_ids]

    real_hr, real_lr = self.transform(real_images)
    wrong_hr, wrong_lr = self.transform(wrong_images)
    batch = [real_hr, real_lr, wrong_hr, wrong_lr]

    if self._embeddings is None:
        batch.extend([None, None])
    else:
        names = [self._filenames[idx] for idx in real_ids]
        classes = [self._class_id[idx] for idx in real_ids]
        embeddings, captions = self.sample_embeddings(
            self._embeddings, real_ids, names, classes, window)
        batch.extend([wrap_Variable(embeddings), captions])

    batch.append(None if self._labels is None else self._labels[real_ids])
    return batch