示例#1
0
def dataset_summary(data_loader, num_classes=1000):
    """Create a histogram of class membership distribution within a dataset.

    It is important to examine our training, validation, and test
    datasets, to make sure that they are balanced.

    Args:
        data_loader: an iterable yielding ``(inputs, labels)`` pairs. It must
            also expose a ``sampler`` attribute supporting ``len()``, which is
            used to report the dataset size.
        num_classes: number of unit-wide histogram bins, with edges
            ``0, 1, ..., num_classes``. Labels outside ``[0, num_classes]``
            are not counted. ``np.histogram`` closes its last bin, so a label
            equal to ``num_classes`` lands in the last bin. Defaults to 1000,
            the value previously hard-coded.
    """
    msglogger.info("Analyzing dataset:")
    print_frequency = 50

    # Gather each batch's labels and concatenate once at the end. Calling
    # np.append per batch would re-copy the accumulated array every time.
    label_chunks = []
    for batch, (_, label_batch) in enumerate(data_loader):
        # Flatten each batch, matching the flattening np.append used to do.
        label_chunks.append(np.ravel(distiller.to_np(label_batch)))
        if (batch + 1) % print_frequency == 0:
            # progress indicator
            print("batch: %d" % batch)

    # An empty loader now yields an all-zero histogram. Previously the label
    # array was never bound and np.histogram raised UnboundLocalError.
    all_labels = np.concatenate(label_chunks) if label_chunks else np.array([])

    counts, _ = np.histogram(all_labels, bins=np.arange(num_classes + 1))
    nclasses = len(counts)
    for data_class, size in enumerate(counts):
        msglogger.info("\tClass {} = {}".format(data_class, size))
    msglogger.info("Dataset contains {} items".format(len(
        data_loader.sampler)))
    msglogger.info("Found {} classes".format(nclasses))
    msglogger.info("Average: {} samples per class".format(np.mean(counts)))
示例#2
0
def test_threshold_mask():
    """Exercise distiller.threshold_mask on a tensor with one small entry.

    A 3x64x32x32 tensor of ones has a single entry lowered to 0.2. It is then
    masked with threshold 0.3. The test checks three things:
    - the mask sums to volume - 1;
    - the mask is zero at the lowered entry;
    - the mask's sparsity is (almost) 1 / volume.
    """
    low_idx = (1, 4, 17, 31)
    tensor = torch.ones(3, 64, 32, 32)
    tensor[low_idx] = 0.2  # the only value below the 0.3 threshold

    mask = distiller.threshold_mask(tensor, threshold=0.3)
    n_elements = distiller.volume(tensor)

    assert np.sum(distiller.to_np(mask)) == (n_elements - 1)
    assert mask[low_idx] == 0
    assert common.almost_equal(distiller.sparsity(mask), 1/n_elements)