def getTrainDataAndTestData(self):
    """Assemble the data bundles for the training and test splits.

    Both splits are read from sub-directories of ``self.src`` (``train/``
    and ``test/``) using the fixed image/label list file names below.

    Returns:
        A ``(trainData, testData)`` tuple. Each element is a dict with:
          * ``'num'``        -- ``len()`` of the split's dataset,
          * ``'dataloader'`` -- a shuffling ``DataLoader`` over the dataset
            (``batch_size=self.batch_size``, 4 workers),
          * ``'label'``      -- whatever ``Pre.loadLabels`` returns for the split,
          * ``'one_hots'``   -- ``Pre.getOnehotCode`` of those labels for
            ``self.num_of_classes`` classes.
    """
    train_dir = self.src + 'train/'
    test_dir = self.src + 'test/'

    # Datasets first, then the raw labels, in the same order for both splits.
    train_set = DS.CIFAR_10_DS(train_dir, 'train_img.txt', 'train_label.txt')
    test_set = DS.CIFAR_10_DS(test_dir, 'test_img.txt', 'test_label.txt')
    train_labels = Pre.loadLabels('train_label.txt', train_dir)
    test_labels = Pre.loadLabels('test_label.txt', test_dir)

    # NOTE(review): both loaders shuffle, while 'label' / 'one_hots' keep the
    # order produced by Pre.loadLabels -- confirm consumers do not assume the
    # batches and these arrays are aligned positionally.
    trainData = dict(
        num=len(train_set),
        dataloader=DataLoader(
            train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
        ),
        label=train_labels,
        one_hots=Pre.getOnehotCode(train_labels, self.num_of_classes),
    )
    testData = dict(
        num=len(test_set),
        dataloader=DataLoader(
            test_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
        ),
        label=test_labels,
        one_hots=Pre.getOnehotCode(test_labels, self.num_of_classes),
    )
    return trainData, testData
def getDataBaseData(self):
    """Assemble the data bundle for the retrieval database split.

    The split is read from the ``database/`` sub-directory of ``self.src``
    using the fixed image/label list file names below.

    Returns:
        A dict with:
          * ``'num'``        -- ``len()`` of the database dataset,
          * ``'dataloader'`` -- a shuffling ``DataLoader`` over the dataset
            (``batch_size=self.batch_size``, 4 workers),
          * ``'label'``      -- whatever ``Pre.loadLabels`` returns for the split,
          * ``'one_hots'``   -- ``Pre.getOnehotCode`` of those labels for
            ``self.num_of_classes`` classes.
    """
    database_dir = self.src + 'database/'
    ds = DS.CIFAR_10_DS(database_dir, 'database_img.txt', 'database_label.txt')

    # Deliberately no trailing comma after this call: a trailing comma would
    # turn the assignment into a 1-tuple wrapping the labels, so 'label' and
    # the one-hot input would no longer be the label collection itself.
    labels = Pre.loadLabels('database_label.txt', database_dir)

    # NOTE(review): the loader shuffles, while 'label' / 'one_hots' keep the
    # order produced by Pre.loadLabels -- confirm consumers do not assume the
    # batches and these arrays are aligned positionally.
    databaseData = {
        'num': len(ds),
        'dataloader': DataLoader(
            ds,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
        ),
        'label': labels,
        'one_hots': Pre.getOnehotCode(labels, self.num_of_classes),
    }
    return databaseData