logger.info('max_epochs : {}'.format(max_epochs))
logger.info('cuda available : {}'.format(torch.cuda.is_available()))

# Data
logger.info('Axial AID with nn.conv2d masked')
# HDF5 cache of the hexagonally resampled AID dataset; the file name embeds
# resize_size[0] so different resize settings get separate caches.
aid_hexa_path = data_directory + '/aid' + str(resize_size[0]) + '_hexa.h5'
if not os.path.exists(aid_hexa_path):
    logger.info('Create hexagonal AID dataset')
    # The index matrix is computed from one resized square image (no
    # SquareToHexa applied) and stored as the 'index_matrix' attribute.
    img, _ = datasets.ImageFolder(
        data_directory + '/AID',
        transform=transforms.Compose([transforms.Resize(resize_size),
                                      transforms.ToTensor()]),
    )[0]
    index_matrix = utils.square_to_hexagonal_index_matrix(img)
    aid = datasets.ImageFolder(
        data_directory + '/AID',
        transform=transforms.Compose([transforms.Resize(resize_size),
                                      transforms.ToTensor(),
                                      utils.SquareToHexa()]),
    )
    images = []
    labels = []
    for i in range(len(aid)):
        image, label = aid[i]
        images.append(image.numpy())
        labels.append(label)
    # Write to a temporary file and move it into place only once it is
    # complete. The os.path.exists() check above would otherwise accept a
    # truncated cache left behind by an interrupted run.
    aid_hexa_tmp_path = aid_hexa_path + '.tmp'
    with h5py.File(aid_hexa_tmp_path, 'w') as f:
        f.create_dataset('images', data=np.array(images))
        f.create_dataset('labels', data=np.array(labels))
        f.attrs['index_matrix'] = index_matrix
        f.attrs['class_names'] = np.array(aid.classes, dtype=h5py.special_dtype(vlen=str))
    os.replace(aid_hexa_tmp_path, aid_hexa_path)

# Load the hexagonal AID cache.
# NOTE(review): `f` is deliberately not closed here because code beyond this
# chunk may still read labels/attributes from it. TODO: close it once done.
f = h5py.File(aid_hexa_path, 'r')
data = f['images'][()]
# MNIST experiment setup (this chunk continues a script whose start is outside
# this view). It:
# - attaches a file log handler writing to
#   '{main_directory}/{experiment_name}/{experiment_name}.log';
# - builds an MNIST train loader (batch_size=64, shuffled) and test loader
#   (test_batch_size=1000). Both use the transform
#   ToTensor -> Normalize((0.1307,), (0.3081,)) -> utils.SquareToHexa().
#   The Normalize constants are presumably the MNIST mean/std (confirm);
# - starts building an MNIST sample to plot a resampled image. That statement
#   is cut off at the end of this view.
# NOTE(review): many statements are collapsed onto this one physical line, so
#   Python cannot parse it as written. In addition, everything after the
#   '# Plot a resampled image to check' marker is a comment. Restore the
#   original line breaks.
# NOTE(review): device is hard-coded to "cuda" without checking
#   torch.cuda.is_available(); moving tensors to it will likely fail on
#   CPU-only hosts.
# NOTE(review): the test loader uses shuffle=True, which evaluation does not
#   need.
logger.addHandler(console_handler) formatter_file = logging.Formatter('%(asctime)s [%(levelname)s] - %(message)s') file_handler = logging.FileHandler('{}/{}/{}.log'.format(main_directory, experiment_name, experiment_name)) file_handler.setFormatter(formatter_file) logger.addHandler(file_handler) batch_size = 64 test_batch_size = 1000 train_set = datasets.MNIST(main_directory + '/../ext_data', train=True, download=True, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), utils.SquareToHexa() ])) train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True) test_loader = torch.utils.data.DataLoader(datasets.MNIST(main_directory + '/../ext_data', train=False, transform=transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), utils.SquareToHexa() ])), batch_size=test_batch_size, shuffle=True) device = torch.device("cuda") # Plot a resampled image to check img, _ = datasets.MNIST(main_directory + '/../ext_data', train=True, download=True, transform=transforms.Compose([
# CIFAR-10 experiment setup (a chunk: the script starts before this view and
# continues after it). Intended behavior:
# - logs batch_size, max_epochs and CUDA availability;
# - when `hexa` is truthy:
#   - sets camera_layout = 'Hex';
#   - computes the hexagonal index matrix from the first CIFAR-10 training
#     image;
#   - if data_directory + '/cifar10.hdf5' does not exist, caches the
#     SquareToHexa-transformed train split there (datasets 'images' and
#     'labels', attribute 'index_matrix');
#   - begins the equivalent caching of the test split into cifar10_test.hdf5.
#     This part is cut off at the end of this view.
# NOTE(review): the statements are collapsed onto one physical line that starts
#   with '# Experiment parameters', so as written the entire line is a comment
#   and does nothing. Restore the original line breaks.
# NOTE(review): the caches are validated only with os.path.exists(). A file
#   left truncated by an interrupted write would be reused; consider writing to
#   a temporary file and os.replace()-ing it into place.
# NOTE(review): each split is first accumulated in Python lists and then copied
#   into one np.array, so the whole split is held in memory (twice) while
#   writing.
# Experiment parameters logger.info('batch_size : {}'.format(batch_size)) logger.info('max_epochs : {}'.format(max_epochs)) logger.info('cuda available : {}'.format(torch.cuda.is_available())) # Data if hexa: camera_layout = 'Hex' logger.info('Hexagonal CIFAR') img, _ = datasets.CIFAR10(data_directory, train=True, download=True, transform=transforms.ToTensor())[0] index_matrix = utils.square_to_hexagonal_index_matrix(img) if not os.path.exists(data_directory + '/cifar10.hdf5'): train_set = datasets.CIFAR10(data_directory, train=True, download=True, transform=transforms.Compose([transforms.ToTensor(), utils.SquareToHexa()])) with h5py.File(data_directory + '/cifar10.hdf5', 'w') as f: images = [] labels = [] for i in range(len(train_set)): image, label = train_set[i] images.append(image.numpy()) labels.append(label) f.create_dataset('images', data=np.array(images)) f.create_dataset('labels', data=np.array(labels)) f.attrs['index_matrix'] = index_matrix if not os.path.exists(data_directory + '/cifar10_test.hdf5'): test_set = datasets.CIFAR10(data_directory, train=False, transform=transforms.Compose([transforms.ToTensor(), utils.SquareToHexa()])) with h5py.File(data_directory + '/cifar10_test.hdf5', 'w') as f: images = []