Пример #1
0
    logger.info('max_epochs : {}'.format(max_epochs))
    logger.info('cuda available : {}'.format(torch.cuda.is_available()))

    # Data
    logger.info('Axial AID with nn.conv2d masked')

    if not os.path.exists(data_directory + '/aid' + str(resize_size[0]) + '_hexa.h5'):
        logger.info('Create hexagonal AID dataset')
        img, _ = datasets.ImageFolder(data_directory + '/AID',
                                      transform=transforms.Compose([transforms.Resize(resize_size),
                                                                    transforms.ToTensor()]))[0]
        index_matrix = utils.square_to_hexagonal_index_matrix(img)
        aid = datasets.ImageFolder(data_directory + '/AID',
                                   transform=transforms.Compose([transforms.Resize(resize_size),
                                                                 transforms.ToTensor(),
                                                                 utils.SquareToHexa()]))
        with h5py.File(data_directory + '/aid' + str(resize_size[0]) + '_hexa.h5', 'w') as f:
            images = []
            labels = []
            for i in range(len(aid)):
                image, label = aid[i]
                images.append(image.numpy())
                labels.append(label)
            f.create_dataset('images', data=np.array(images))
            f.create_dataset('labels', data=np.array(labels))
            f.attrs['index_matrix'] = index_matrix
            f.attrs['class_names'] = np.array(aid.classes, dtype=h5py.special_dtype(vlen=str))

    # load hexagonal AID
    f = h5py.File(data_directory + '/aid' + str(resize_size[0]) + '_hexa.h5', 'r')
    data = f['images'][()]
Пример #2
0
# Attach the console handler (it must already exist by this point in the script).
logger.addHandler(console_handler)

# Also write every record to <main_directory>/<experiment_name>/<experiment_name>.log,
# each line prefixed with a timestamp and the level name.
log_file_path = '{root}/{name}/{name}.log'.format(root=main_directory, name=experiment_name)
file_handler = logging.FileHandler(log_file_path)
formatter_file = logging.Formatter('%(asctime)s [%(levelname)s] - %(message)s')
file_handler.setFormatter(formatter_file)
logger.addHandler(file_handler)

batch_size = 64
test_batch_size = 1000


def _mnist_hexa_transform():
    """Build a fresh preprocessing pipeline for one MNIST split.

    The pipeline runs three steps:
    1. Converts to a tensor.
    2. Normalizes with (0.1307,) / (0.3081,). These are presumably the MNIST
       train-set mean/std; confirm.
    3. Applies utils.SquareToHexa(). Judging by its name, this resamples the
       square image onto a hexagonal grid; confirm in utils.

    A new instance is returned for each split, as in the original code. This
    avoids sharing any state that SquareToHexa might keep while still keeping a
    single definition of the pipeline.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        utils.SquareToHexa()
    ])


train_set = datasets.MNIST(main_directory + '/../ext_data', train=True, download=True,
                           transform=_mnist_hexa_transform())
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)

# Pass download=True explicitly so the test split does not depend on the train
# split having fetched the files first.
test_set = datasets.MNIST(main_directory + '/../ext_data', train=False, download=True,
                          transform=_mnist_hexa_transform())
# NOTE(review): shuffling the test split is kept from the original; it does not
# affect aggregate evaluation metrics.
test_loader = torch.utils.data.DataLoader(test_set, batch_size=test_batch_size, shuffle=True)

# Fall back to CPU when CUDA is unavailable instead of failing at the first .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Plot a resampled image to check
img, _ = datasets.MNIST(main_directory + '/../ext_data', train=True, download=True,
                        transform=transforms.Compose([
Пример #3
0
    # Experiment parameters

    logger.info('batch_size : {}'.format(batch_size))
    logger.info('max_epochs : {}'.format(max_epochs))
    logger.info('cuda available : {}'.format(torch.cuda.is_available()))

    # Data
    if hexa:
        camera_layout = 'Hex'
        logger.info('Hexagonal CIFAR')
        img, _ = datasets.CIFAR10(data_directory, train=True, download=True, transform=transforms.ToTensor())[0]
        index_matrix = utils.square_to_hexagonal_index_matrix(img)

        if not os.path.exists(data_directory + '/cifar10.hdf5'):
            train_set = datasets.CIFAR10(data_directory, train=True, download=True,
                                         transform=transforms.Compose([transforms.ToTensor(), utils.SquareToHexa()]))
            with h5py.File(data_directory + '/cifar10.hdf5', 'w') as f:
                images = []
                labels = []
                for i in range(len(train_set)):
                    image, label = train_set[i]
                    images.append(image.numpy())
                    labels.append(label)
                f.create_dataset('images', data=np.array(images))
                f.create_dataset('labels', data=np.array(labels))
                f.attrs['index_matrix'] = index_matrix
        if not os.path.exists(data_directory + '/cifar10_test.hdf5'):
            test_set = datasets.CIFAR10(data_directory, train=False,
                                        transform=transforms.Compose([transforms.ToTensor(), utils.SquareToHexa()]))
            with h5py.File(data_directory + '/cifar10_test.hdf5', 'w') as f:
                images = []