# Example 1
def plot_activations(net, x, y_label, layer_plot_sizes):
    """Plot the forward activations of every layer in a sequential network.

    Params
        net (torch.nn.Sequential): neural network with indexable layers
        x (tensor): single data point of 4 dimensions: [batch, channel, height, width]
        y_label (string): used for plotting and file names
        layer_plot_sizes (dict): {layer_num: [(layer_input_shape), (num_rows, num_cols)]}
    Output:
        images in directory "activation-imgs/{y_label}/...". You must create this directory
    """
    # output[0] is the raw input; output[i + 1] is the activation after layer i.
    output = dict()
    output[0] = x
    for i in range(len(net)):
        output[i + 1] = net[i](output[i])

    for i in range(len(output)):
        d2l.show_images(output[i].reshape(
            layer_plot_sizes[i][0]).cpu().detach().numpy(),
                        num_rows=layer_plot_sizes[i][1][0],
                        num_cols=layer_plot_sizes[i][1][1])
        # There are len(net) + 1 activations but only len(net) layers, so the
        # final index has no net[i]; name its file after the layer that
        # produced it instead.  (The original hard-coded `i == 13`, which
        # never fired for a 12-layer net — the last index is 12, so net[12]
        # raised IndexError.)
        layer = net[i - 1] if i == len(net) else net[i]
        plt.savefig(
            f'activation-imgs/{y_label}/{y_label}-layer-{i}-{layer.__class__.__name__}-{layer_plot_sizes[i][0]}'
        )
# Example 2
def predict_ch3(net, test_iter, n=6):
    """Predict Fashion-MNIST labels for one test batch and plot `n` examples.

    Params
        net: trained model; called as net(X)
        test_iter: iterable of (X, y) batches
        n (int): number of examples to display (default 6)
    Output:
        shows `n` images titled "true_label\npredicted_label"
    """
    # Grab just the first batch — the original `for ...: break` idiom.
    X, y = next(iter(test_iter))
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [true + '\n' + pred for true, pred in zip(trues, preds)]
    d2l.show_images(X[0:n].reshape((n, 28, 28)), 1, n, titles=titles[0:n])
# Example 3
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    """Run the augmentation `aug` on `img` num_rows * num_cols times and
    display the results in a num_rows x num_cols grid."""
    total = num_rows * num_cols
    augmented = []
    for _ in range(total):
        augmented.append(aug(img))
    d2l.show_images(augmented, num_rows, num_cols, scale=scale)
# Train with SGD plus L2 weight decay.  NOTE(review): lr, wd, net, loss,
# train_iter, test_iter, num_epochs and devices are defined elsewhere in
# the file — confirm they are in scope before this point.
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)


def predict(img):
    """Run segmentation inference on a single image and return the per-pixel
    class-index map (shape (H, W))."""
    # Normalize and add a batch axis so the network sees a 4-D input.
    batch = test_iter.dataset.normalize_image(img).unsqueeze(0)
    scores = net(batch.to(devices[0]))
    labels = scores.argmax(dim=1)
    # argmax over the class dim leaves (1, H, W); drop the batch axis.
    return labels.reshape(labels.shape[1], labels.shape[2])


def label2image(pred):
    """Map a per-pixel class-index tensor to an RGB image using the
    Pascal VOC colormap."""
    palette = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])
    indices = pred.long()
    # Fancy-index the palette: each class index picks its RGB row.
    return palette[indices, :]


# Download/extract Pascal VOC2012 and read the test split (False = not train).
voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):
    # Fixed top-left crop so every example has the same spatial size.
    crop_rect = (0, 0, 320, 480)
    X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)
    pred = label2image(predict(X))
    imgs += [
        X.permute(1, 2, 0),  # CHW -> HWC for plotting
        pred.cpu(),
        torchvision.transforms.functional.crop(test_labels[i],
                                               *crop_rect).permute(1, 2, 0)
    ]
# imgs holds interleaved [input, prediction, label] triples; the slicing
# regroups them into three rows: inputs, predictions, ground truth.
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2)