Code example #1
File: DAGH_flip.py  Project: EricKing19/ADSH
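Note: this snippet assumes its imports are defined elsewhere in DAGH_flip.py, e.g. `os`, `torch`, `torchvision.transforms as transforms`, the project's dataset-processing module as `dp`, and its `encoding_onehot` helper; the exact import statements are not shown here.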
def _dataset():
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transformations = transforms.Compose([
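        # Note: transforms.Scale is deprecated; newer torchvision releases name this transforms.Resize.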
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])

    dset_database = dp.DatasetProcessingCIFAR_10('data/CIFAR-10',
                                                 'database_img.txt',
                                                 'database_label.txt',
                                                 transformations)
    dset_test = dp.DatasetProcessingCIFAR_10('data/CIFAR-10', 'test_img.txt',
                                             'test_label.txt', transformations)
    num_database, num_test = len(dset_database), len(dset_test)

    def load_label(filename, DATA_DIR):
        # Read one integer label per line and return them as a LongTensor.
        path = os.path.join(DATA_DIR, filename)
        with open(path, 'r') as fp:
            labels = [x.strip() for x in fp]
        return torch.LongTensor(list(map(int, labels)))

    testlabels = load_label('test_label.txt', 'data/CIFAR-10')
    databaselabels = load_label('database_label.txt', 'data/CIFAR-10')

    testlabels = encoding_onehot(testlabels)
    databaselabels = encoding_onehot(databaselabels)

    dsets = (dset_database, dset_test)
    nums = (num_database, num_test)
    labels = (databaselabels, testlabels)
    return nums, dsets, labels
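A minimal usage sketch for the function above, assuming `torch.utils.data.DataLoader` is available; the batch size and worker count are illustrative and not taken from the original file.

from torch.utils.data import DataLoader

nums, dsets, labels = _dataset()
num_database, num_test = nums
dset_database, dset_test = dsets
databaselabels, testlabels = labels

# Wrap the datasets in loaders for retrieval-style evaluation.
database_loader = DataLoader(dset_database, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(dset_test, batch_size=64, shuffle=False, num_workers=4)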
Code example #2
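Note: as above, the imports (`os`, `pickle`, `torch`, `torchvision.transforms`, `torch.utils.data.DataLoader`, and the project's `dp`, `cnn_model`, and `encode` helpers) are assumed to be defined elsewhere in the source file.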
def create_dataset(dataset_name):
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    transformations = transforms.Compose([
        transforms.Scale(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(), normalize
    ])
    if dataset_name == 'NUS-WIDE':
        dset_test = dp.DatasetProcessingNUS_WIDE('data/NUS-WIDE',
                                                 'test_img.txt',
                                                 transformations)
        return dset_test
    if dataset_name == 'CIFAR-10':
        dset_test = dp.DatasetProcessingCIFAR_10('data/CIFAR-10',
                                                 'test_img.txt',
                                                 transformations)
        return dset_test
    if dataset_name == 'Project':
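        # For the custom project dataset, precompute 48-bit binary codes for the
        # database split the first time this branch runs and cache them in a pickle.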
        if not os.path.exists('dcodes/adch-project-48bits-record.pkl'):
            record = {}
            dset_database = dp.DatasetProcessingPorject(
                'data/Project', 'database_img.txt', transformations)
            databaseloader = DataLoader(dset_database,
                                        batch_size=1,
                                        shuffle=False,
                                        num_workers=4)
            model = cnn_model.CNNNet('resnet50', 48)
            model.load_state_dict(
                torch.load('dict/adch-nuswide-48bits.pth',
                           map_location=torch.device('cpu')))
            model.eval()
            rB = encode(model, databaseloader, 4985, 48)
            record['rB'] = rB
            with open('dcodes/adch-project-48bits-record.pkl', 'wb') as fp:
                pickle.dump(record, fp)
        dset_test = dp.DatasetProcessingPorject('data/Project', 'test_img.txt',
                                                transformations)
        return dset_test
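A rough usage sketch, assuming the data directories listed above exist on disk; the dataset name and loader parameters are illustrative.

from torch.utils.data import DataLoader

test_set = create_dataset('CIFAR-10')  # or 'NUS-WIDE' / 'Project'
test_loader = DataLoader(test_set,
                         batch_size=1,
                         shuffle=False,
                         num_workers=4)
print(len(test_set), 'test images')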