def get_radio_ml_loader(batch_size, train, **kwargs):
    """Create a DataLoader for one split of the RadioML dataset.

    NOTE(review): ``get_radio_ml_loader`` is defined a second time further
    down this module. Each ``def`` rebinds the module-level name, so callers
    importing this module get that later definition, not this one. One of
    the two should be removed or renamed.

    Args:
        batch_size: Samples per batch handed to ``DataLoader``.
        train: Selects the training split when true (and enables
            shuffling); otherwise the test split is loaded in order.
        **kwargs: Dataset options. ``data_dir`` is required (a missing key
            raises ``KeyError``). Optional keys and their fallbacks:
            ``normalize`` (True), ``fake_height`` (False), ``min_snr`` (6),
            ``max_snr`` (30), ``per_h5_frac`` (0.5), ``train_frac`` (0.9),
            ``skip_1`` (False), ``per_sample_frac`` (1.0), ``classes`` (24).
            Any other keys are ignored.

    Returns:
        A ``DataLoader`` whose ``name`` attribute is ``'RadioML_train'`` or
        ``'RadioML_test'``.
    """
    root_dir = kwargs['data_dir']

    # (option name, fallback) pairs forwarded to RadioMLDataset as keyword
    # arguments; a value supplied by the caller takes precedence.
    fallbacks = (
        ('normalize', True),
        ('fake_height', False),
        ('min_snr', 6),
        ('max_snr', 30),
        ('per_h5_frac', 0.5),
        ('train_frac', 0.9),
        ('skip_1', False),
        ('per_sample_frac', 1.0),
        ('classes', 24),
    )
    dataset_options = {
        option: kwargs.get(option, fallback) for option, fallback in fallbacks
    }
    dataset = RadioMLDataset(root_dir, train, **dataset_options)

    if train:
        split_label = 'train'
    else:
        split_label = 'test'
    print('[%s] dataset size: %d' % (split_label, len(dataset)))

    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train)
    data_loader.name = 'RadioML_{}'.format(split_label)
    return data_loader
def get_radio_ml_loader_2016(batch_size, X, Y, train, normalize):
    """Wrap already-loaded RadioML 2016 data in a DataLoader.

    Args:
        batch_size: Samples per batch handed to ``DataLoader``.
        X: Sample data passed straight to ``RadioMLDataset2016``.
        Y: Passed straight to ``RadioMLDataset2016``; presumably the labels
            matching ``X`` (TODO confirm against RadioMLDataset2016).
        train: When true the loader shuffles and is labelled ``'train'``;
            otherwise it keeps order and is labelled ``'test'``.
        normalize: Forwarded unchanged to ``RadioMLDataset2016``.

    Returns:
        A ``DataLoader`` whose ``name`` attribute is ``'RadioML2016_train'``
        or ``'RadioML2016_test'``.
    """
    dataset = RadioMLDataset2016(X, Y, normalize=normalize)

    if train:
        split_label = 'train'
    else:
        split_label = 'test'
    print('[%s] dataset size: %d' % (split_label, len(dataset)))

    data_loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train)
    data_loader.name = 'RadioML2016_{}'.format(split_label)
    return data_loader
def get_mnist_loader(batch_size, train, taskid=0, **kwargs):
    """Build a DataLoader over torchvision's MNIST digits.

    Args:
        batch_size: Samples per batch handed to ``DataLoader``.
        train: If true, load the training split and shuffle it; otherwise
            load the test split in order.
        taskid: Identifier stored as ``loader.taskid`` and embedded in
            ``loader.name``.
        **kwargs: Optional settings. ``root`` (default ``'./data'``) is the
            directory MNIST is read from / downloaded into, and ``download``
            (default ``True``) is passed to ``datasets.MNIST``. Any other
            keys are ignored.

    Returns:
        A ``DataLoader`` with extra attributes ``taskid``, ``name``
        (``'MNIST_<taskid>'``) and ``short_name`` (``'MNIST'``).
    """
    # Previously hard-coded; the defaults keep the old behaviour.
    root = kwargs.get('root', './data')
    download = kwargs.get('download', True)

    transform = transforms.Compose([
        transforms.Grayscale(),
        transforms.ToTensor(),
        # Mean 0 / std 1: (x - 0) / 1 leaves the tensor values unchanged.
        transforms.Normalize((0.0, ), (1.0, )),
        # Reshape each sample to a 28x28 tensor, dropping the channel axis
        # that ToTensor adds.
        transforms.Lambda(lambda x: x.view([28, 28])),
    ])
    dataset = datasets.MNIST(root=root, download=download, transform=transform, train=train)

    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train)
    loader.taskid = taskid
    loader.name = 'MNIST_{}'.format(taskid)
    loader.short_name = 'MNIST'
    return loader
def get_radio_ml_loader(batch_size, train, **kwargs):
    """Create a DataLoader for one split of the RadioML dataset.

    NOTE(review): this module defines ``get_radio_ml_loader`` twice. Each
    ``def`` rebinds the module-level name, so this later definition replaces
    the earlier one at import time and is the version callers actually get.
    The earlier definition accepts more dataset options and defaults
    ``normalize`` to True; this one defaults ``normalize`` to False. It also
    forwards the extra options when the caller supplies them, instead of
    silently dropping them. One of the two definitions should be removed.

    Args:
        batch_size: Samples per batch handed to ``DataLoader``.
        train: Selects the training split when true (and enables
            shuffling); otherwise the test split is loaded in order.
        **kwargs: Dataset options. ``data_dir`` is required (a missing key
            raises ``KeyError``). Optional keys and their defaults:
            ``min_snr`` (6), ``max_snr`` (30), ``per_h5_frac`` (0.5),
            ``train_frac`` (0.9), ``normalize`` (False). ``fake_height``,
            ``skip_1``, ``per_sample_frac`` and ``classes`` are forwarded to
            ``RadioMLDataset`` only when present, so the dataset's own
            defaults apply otherwise. Any other keys are ignored.

    Returns:
        A ``DataLoader`` whose ``name`` attribute is ``'RadioML_train'`` or
        ``'RadioML_test'``.
    """
    data_dir = kwargs['data_dir']
    min_snr = kwargs.get('min_snr', 6)
    max_snr = kwargs.get('max_snr', 30)
    per_h5_frac = kwargs.get('per_h5_frac', 0.5)
    train_frac = kwargs.get('train_frac', 0.9)
    # Was hard-coded to False, ignoring the caller; False stays the default.
    normalize = kwargs.get('normalize', False)
    # Forward these only when given, so a call without them reaches
    # RadioMLDataset exactly as before.
    optional_keys = ('fake_height', 'skip_1', 'per_sample_frac', 'classes')
    extra_options = {key: kwargs[key] for key in optional_keys if key in kwargs}

    dataset = RadioMLDataset(data_dir, train, normalize=normalize,
                             min_snr=min_snr, max_snr=max_snr,
                             per_h5_frac=per_h5_frac, train_frac=train_frac,
                             **extra_options)
    identifier = 'train' if train else 'test'
    print('[%s] dataset size: %d' % (identifier, len(dataset)))
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=train)
    loader.name = 'RadioML_{}'.format(identifier)
    return loader