Ejemplo n.º 1
0
    logger.addHandler(file_handler)

    # Experiment parameters
    logger.info('batch_size : {}'.format(batch_size))
    logger.info('max_epochs : {}'.format(max_epochs))
    logger.info('cuda available : {}'.format(torch.cuda.is_available()))

    # Data
    if hexa:
        camera_layout = 'Hex'
        logger.info('Hexagonal CIFAR')
        img, _ = datasets.CIFAR10(data_directory,
                                  train=True,
                                  download=True,
                                  transform=transforms.ToTensor())[0]
        index_matrix = utils.square_to_hexagonal_index_matrix(img)

        if not os.path.exists(data_directory + '/cifar10.hdf5'):
            train_set = datasets.CIFAR10(data_directory,
                                         train=True,
                                         download=True,
                                         transform=transforms.Compose([
                                             transforms.ToTensor(),
                                             utils.SquareToHexa()
                                         ]))
            with h5py.File(data_directory + '/cifar10.hdf5', 'w') as f:
                images = []
                labels = []
                for i in range(len(train_set)):
                    image, label = train_set[i]
                    images.append(image.numpy())
Ejemplo n.º 2
0
 def test_square_to_hexagonal_index_matrix(self):
     """Check ``utils.square_to_hexagonal_index_matrix`` against a reference.

     Computes the index matrix from ``self.square_image`` and compares it with
     ``self.square_to_hexagonal_index_matrix``. Both attributes are fixtures
     presumably initialised elsewhere in this test case (e.g. ``setUp``) --
     not visible here.
     """
     actual = utils.square_to_hexagonal_index_matrix(self.square_image)
     expected = self.square_to_hexagonal_index_matrix
     # ``torch.testing.assert_allclose`` is deprecated (PyTorch >= 1.12) in
     # favour of ``assert_close``. Mirror its lenient behaviour: coerce
     # non-tensor values to tensors, tolerate dtype mismatches, and treat
     # NaN == NaN. NOTE(review): assert_close's default float tolerances are
     # slightly tighter; irrelevant if the index matrix is integer-valued.
     if not isinstance(actual, torch.Tensor):
         actual = torch.tensor(actual)
     if not isinstance(expected, torch.Tensor):
         expected = torch.tensor(expected, dtype=actual.dtype)
     torch.testing.assert_close(actual, expected,
                                check_dtype=False, equal_nan=True)