class COCOData(data.Dataset):
    """COCO captioning dataset.

    Serves (image tensor, numericalized caption tensor) pairs. Images are
    fetched over HTTP from their `coco_url` and normalized; captions are
    wrapped with SOS/EOS tokens and mapped to vocabulary indices.
    """

    def __init__(self, **kwargs):
        """Build the dataset.

        Expected kwargs:
            stage: dataset stage label (e.g. train/val).
            coco_interface: a pycocotools COCO instance.
            datalist: list of annotation ids for the relevant class(es).
        """
        self.stage = kwargs['stage']
        self.coco_interface = kwargs['coco_interface']
        # this returns the list of image objects, equal to the number
        # of images of the relevant class(es)
        self.datalist = kwargs['datalist']
        # load the annotation records for the listed ids
        self.ann_data = self.coco_interface.loadAnns(self.datalist)
        self.captions = [ann["caption"] for ann in self.ann_data]
        self.image_ids = [ann["image_id"] for ann in self.ann_data]
        self.vocab = Vocabulary(Constants.word_threshold)
        # BUG FIX: original used bitwise `&`, which does not short-circuit
        # and silently coerces to int; logical `and` is the intended test.
        if os.path.exists(Constants.vocab_file) and Constants.vocab_from_file:
            self.vocab.get_vocab()
        else:
            self.vocab.build_vocabulary(self.captions)

    # this method normalizes the image and converts it to a PyTorch tensor.
    # Here we use pytorch transforms functionality, and Compose them together.
    def transform(self, img, is_grayscale):
        """Normalize `img` and return it as a tensor (CxHxW).

        `is_grayscale` flags a 2-D input image; `Hyper.is_grayscale` selects
        whether the MODEL consumes single-channel input.
        """
        if Hyper.is_grayscale:
            # BUG FIX: Grayscale(num_output_channels=1) emits ONE channel,
            # so Normalize must use 1-element mean/std. The original passed
            # 3-element lists (which raises on a 1-channel tensor) and had
            # an obvious typo std=[0.5, 0.5, 5].
            t_ = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                transforms.Grayscale(num_output_channels=1),
                transforms.Normalize(mean=[0.5], std=[0.5]),
            ])
        else:
            t_ = transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor(),
                # replicate a single-channel image to 3 channels so the
                # ImageNet RGB statistics below apply
                transforms.Lambda(lambda x: x.repeat(3, 1, 1))
                if is_grayscale else NoneTransform(),
                # these mean values are the standard ImageNet RGB statistics
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
        # returns image tensor (CxHxW) — needed as input for the model
        return t_(img)

    def load_img(self, idx):
        """Download the image for annotation `idx` and return it normalized."""
        img_id = self.ann_data[idx]["image_id"]
        path = self.coco_interface.loadImgs(img_id)
        coco_url = path[0]["coco_url"]
        img = io.imread(coco_url)
        is_grayscale = self.check_if_grayscale(img)
        im = np.array(img)
        return self.transform(im, is_grayscale)

    def check_if_grayscale(self, img):
        """Return True when `img` is not a 3-channel (RGB) array.

        Grayscale images have 2 dimensions only; RGB images have 3
        dimensions with the third dimension equal to 3.
        NOTE(review): a 4-channel (RGBA) image also reports True here —
        TODO confirm that is intended.
        """
        return not (len(img.shape) == 3 and img.shape[2] == 3)

    def load_caption(self, idx):
        """Return the raw caption string for annotation `idx`."""
        return self.ann_data[idx]['caption']

    def __len__(self):
        # number of annotations (== number of servable samples)
        return len(self.datalist)

    def __getitem__(self, idx):
        """Return (image tensor, caption index tensor) for sample `idx`.

        The caption is bracketed with SOS/EOS tokens before tensorizing.
        """
        img = self.load_img(idx)
        caption = self.load_caption(idx)
        numericalized_caption = [self.vocab.stoi[Constants.SOS]]
        numericalized_caption += self.vocab.numericalize(caption)
        numericalized_caption.append(self.vocab.stoi[Constants.EOS])
        return img, torch.tensor(numericalized_caption)

    def test_interface_with_single_image(self, image_id):
        """Debug helper: render one sample to test.png and print its caption."""
        ann_ids = self.coco_interface.loadAnns(image_id)
        img, caption = self[image_id]
        self.coco_interface.showAnns(ann_ids)
        # CxHxW tensor -> HxWxC for matplotlib
        image = img.squeeze().permute(1, 2, 0)
        plt.imshow(image)
        plt.savefig("test.png")
        print("Image saved to test.png. The associated caption is: ")
        print(caption)