def _dataset():
    """Build the CIFAR-10 database/test datasets and load their labels.

    Image lists and label files are read from ``data/CIFAR-10``.

    Returns:
        A 3-tuple ``(nums, dsets, labels)``:

        * ``nums``   -- ``(num_database, num_test)``, the ``len()`` of each
          dataset.
        * ``dsets``  -- ``(dset_database, dset_test)``.
        * ``labels`` -- ``(databaselabels, testlabels)``; each is the result
          of ``encoding_onehot`` applied to a ``torch.LongTensor`` holding
          the integers read from the corresponding label file.
    """
    data_dir = 'data/CIFAR-10'

    # ``transforms.Scale`` is the old name of ``transforms.Resize`` and is
    # missing from recent torchvision releases.  Prefer ``Resize`` and fall
    # back to ``Scale`` so very old installs keep working.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale

    # NOTE(review): these look like the usual ImageNet channel statistics --
    # confirm they match the backbone this pipeline feeds.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transformations = transforms.Compose([
        resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    dset_database = dp.DatasetProcessingCIFAR_10(
        data_dir, 'database_img.txt', 'database_label.txt', transformations)
    dset_test = dp.DatasetProcessingCIFAR_10(
        data_dir, 'test_img.txt', 'test_label.txt', transformations)
    num_database, num_test = len(dset_database), len(dset_test)

    def load_label(filename, DATA_DIR):
        """Read one integer per line from ``DATA_DIR/filename``.

        Returns a ``torch.LongTensor`` with one entry per line.  Raises
        ``ValueError`` if a line is not a valid integer.
        """
        path = os.path.join(DATA_DIR, filename)
        # ``with`` guarantees the handle is closed even if parsing fails.
        with open(path, 'r') as fp:
            values = [int(line.strip()) for line in fp]
        return torch.LongTensor(values)

    # presumably ``encoding_onehot`` turns class indices into one-hot rows;
    # it is defined outside this block -- verify against its definition.
    testlabels = encoding_onehot(load_label('test_label.txt', data_dir))
    databaselabels = encoding_onehot(
        load_label('database_label.txt', data_dir))

    dsets = (dset_database, dset_test)
    nums = (num_database, num_test)
    labels = (databaselabels, testlabels)
    return nums, dsets, labels
def create_dataset(dataset_name):
    """Return the test-split dataset for ``dataset_name``.

    Supported names are ``'NUS-WIDE'``, ``'CIFAR-10'`` and ``'Project'``.

    For ``'Project'``, if ``dcodes/adch-project-48bits-record.pkl`` does not
    exist yet, the database images are first passed through ``encode`` with
    a ``resnet50`` / 48-bit ``cnn_model.CNNNet`` loaded from
    ``dict/adch-nuswide-48bits.pth`` and the result is pickled to that file
    as ``{'rB': ...}``.

    Args:
        dataset_name: Name of the dataset to build.

    Returns:
        The dataset object for the requested test split, or ``None`` when
        ``dataset_name`` is not one of the supported names.
    """
    # ``transforms.Scale`` is the old name of ``transforms.Resize`` and is
    # missing from recent torchvision releases.  Prefer ``Resize`` and fall
    # back to ``Scale`` so very old installs keep working.
    resize = getattr(transforms, 'Resize', None) or transforms.Scale

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transformations = transforms.Compose([
        resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    if dataset_name == 'NUS-WIDE':
        return dp.DatasetProcessingNUS_WIDE(
            'data/NUS-WIDE', 'test_img.txt', transformations)

    if dataset_name == 'CIFAR-10':
        # The third positional argument of this class is the label file
        # (that is how ``_dataset`` in this file calls it); the previous call
        # here passed the transform in that slot and supplied no label file.
        # NOTE(review): confirm against dp.DatasetProcessingCIFAR_10.
        return dp.DatasetProcessingCIFAR_10(
            'data/CIFAR-10', 'test_img.txt', 'test_label.txt',
            transformations)

    if dataset_name == 'Project':
        record_path = 'dcodes/adch-project-48bits-record.pkl'
        if not os.path.exists(record_path):
            dset_database = dp.DatasetProcessingPorject(
                'data/Project', 'database_img.txt', transformations)
            databaseloader = DataLoader(dset_database,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=4)

            model = cnn_model.CNNNet('resnet50', 48)
            # NOTE(review): the checkpoint name says "nuswide" while the
            # record is for "project" -- presumably intentional reuse of the
            # NUS-WIDE weights; confirm.
            model.load_state_dict(
                torch.load('dict/adch-nuswide-48bits.pth',
                           map_location=torch.device('cpu')))
            model.eval()

            # NOTE(review): 4985 is presumably the number of database
            # images and 48 the code length -- verify against ``encode``.
            rB = encode(model, databaseloader, 4985, 48)
            record = {'rB': rB}

            # Make sure the output directory exists so the encoding work is
            # not thrown away by a FileNotFoundError on a fresh checkout.
            os.makedirs(os.path.dirname(record_path), exist_ok=True)
            with open(record_path, 'wb') as fp:
                pickle.dump(record, fp)

        return dp.DatasetProcessingPorject(
            'data/Project', 'test_img.txt', transformations)

    # Unrecognised names yield None, as before (previously implicit).
    return None