def dataset_summary(data_loader, num_classes=1000):
    """Log a histogram of class membership distribution within a dataset.

    It is important to examine our training, validation, and test datasets,
    to make sure that they are balanced.

    Args:
        data_loader: an iterable yielding ``(inputs, labels)`` batches. It must
            also expose a ``sampler`` attribute; ``len(data_loader.sampler)``
            is logged as the number of items in the dataset.
        num_classes: number of class bins in the histogram. Labels are
            expected to be integers in ``[0, num_classes)``; labels outside
            that range are not counted. Defaults to 1000, the value that was
            previously hard-coded here.

    Raises:
        ValueError: if ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError("num_classes must be at least 1, got {}".format(num_classes))
    msglogger.info("Analyzing dataset:")
    print_frequency = 50
    # Collect per-batch label arrays and concatenate once at the end:
    # calling np.append per batch copies the whole accumulated array every
    # time, which is quadratic in the dataset size.
    label_chunks = []
    for batch, (_inputs, label_batch) in enumerate(data_loader):
        # np.ravel mirrors np.append's flattening of its inputs.
        label_chunks.append(np.ravel(distiller.to_np(label_batch)))
        if (batch + 1) % print_frequency == 0:
            # progress indicator
            print("batch: %d" % batch)
    # An empty loader used to crash with NameError; report all-zero counts instead.
    if label_chunks:
        all_labels = np.concatenate(label_chunks)
    else:
        all_labels = np.array([])
    # Half-integer bin edges give each integer label its own bin. With integer
    # edges the last bin is closed on the right, so a label equal to
    # num_classes was silently counted as class num_classes - 1.
    bin_edges = np.arange(num_classes + 1) - 0.5
    counts, _ = np.histogram(all_labels, bins=bin_edges)
    nclasses = len(counts)
    for data_class, size in enumerate(counts):
        msglogger.info("\tClass {} = {}".format(data_class, size))
    msglogger.info("Dataset contains {} items".format(len(data_loader.sampler)))
    msglogger.info("Found {} classes".format(nclasses))
    msglogger.info("Average: {} samples per class".format(np.mean(counts)))
def test_threshold_mask():
    """threshold_mask must zero exactly the elements below the threshold."""
    # A 4-D tensor of ones in which a single element falls below the threshold.
    weights = torch.ones(3, 64, 32, 32)
    low_idx = (1, 4, 17, 31)
    weights[low_idx] = 0.2

    mask = distiller.threshold_mask(weights, threshold=0.3)

    n_elements = distiller.volume(weights)
    # All elements but the single low one survive the mask...
    assert np.sum(distiller.to_np(mask)) == (n_elements - 1)
    # ...and the low one is masked out, giving a sparsity of 1/N.
    assert mask[low_idx] == 0
    assert common.almost_equal(distiller.sparsity(mask), 1 / n_elements)