def plot_activations(net, x, y_label, layer_plot_sizes):
    """Plot the forward activations of every layer in a sequential network.

    Params
        net (torch.nn.Sequential): neural network
        x (tensor): single data point of 4 dimensions:
            [batch, channel, height, width]
        y_label (string): used for plotting and file names
        layer_plot_sizes (dict): {layer_num: [(layer_input_shape), (num_rows, num_cols)]}

    Output: images in directory "activation-imgs/{y_label}/...".
        You must create this directory beforehand.
    """
    # output[0] is the raw input; output[i + 1] is the activation after net[i].
    output = {0: x}
    for i in range(len(net)):
        output[i + 1] = net[i](output[i])
    for i in range(len(output)):
        d2l.show_images(
            output[i].reshape(layer_plot_sizes[i][0]).cpu().detach().numpy(),
            num_rows=layer_plot_sizes[i][1][0],
            num_cols=layer_plot_sizes[i][1][1])
        # output has len(net) + 1 entries, so the final index has no matching
        # net[i]; name its file after the last layer instead.  (The original
        # hard-coded `i == 13`, which never matched for a 12-layer net and
        # let `net[i]` raise IndexError on the final activation.)
        layer = net[i - 1] if i == len(net) else net[i]
        plt.savefig(
            f'activation-imgs/{y_label}/{y_label}-layer-{i}-'
            f'{layer.__class__.__name__}-{layer_plot_sizes[i][0]}')
def predict_ch3(net, test_iter, n=6):
    """Show the first `n` test images with true and predicted labels."""
    # Grab a single batch from the test iterator.
    X, y = next(iter(test_iter))
    trues = d2l.get_fashion_mnist_labels(y)
    preds = d2l.get_fashion_mnist_labels(net(X).argmax(axis=1))
    titles = [f'{t}\n{p}' for t, p in zip(trues, preds)]
    d2l.show_images(X[:n].reshape((n, 28, 28)), 1, n, titles=titles[:n])
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    """Run augmentation `aug` on `img` several times and show a grid of results."""
    augmented = []
    for _ in range(num_rows * num_cols):
        augmented.append(aug(img))
    d2l.show_images(augmented, num_rows, num_cols, scale=scale)
# Train the segmentation network, then visualize predictions on fixed crops
# of the VOC2012 test images.
trainer = torch.optim.SGD(net.parameters(), lr=lr, weight_decay=wd)
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs, devices)


def predict(img):
    """Return the per-pixel predicted class-index map for one image."""
    X = test_iter.dataset.normalize_image(img).unsqueeze(0)
    scores = net(X.to(devices[0]))
    pred = scores.argmax(dim=1)
    # Drop the batch dimension: [1, H, W] -> [H, W].
    return pred.reshape(pred.shape[1], pred.shape[2])


def label2image(pred):
    """Map a class-index map to an RGB image via the VOC colormap."""
    colormap = torch.tensor(d2l.VOC_COLORMAP, device=devices[0])
    return colormap[pred.long(), :]


voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')
test_images, test_labels = d2l.read_voc_images(voc_dir, False)
n, imgs = 4, []
for i in range(n):
    crop_rect = (0, 0, 320, 480)
    X = torchvision.transforms.functional.crop(test_images[i], *crop_rect)
    pred = label2image(predict(X))
    cropped_label = torchvision.transforms.functional.crop(
        test_labels[i], *crop_rect)
    imgs += [X.permute(1, 2, 0), pred.cpu(), cropped_label.permute(1, 2, 0)]
# Row 1: inputs, row 2: predictions, row 3: ground-truth labels.
d2l.show_images(imgs[::3] + imgs[1::3] + imgs[2::3], 3, n, scale=2)