Ejemplo n.º 1
0
 def __init__(self,
              xlabel=None,
              ylabel=None,
              legend=None,
              xlim=None,
              ylim=None,
              xscale='linear',
              yscale='linear',
              fmts=('-', 'm--', 'g-.', 'r:'),
              nrows=1,
              ncols=1,
              figsize=(3.5, 2.5)):
     """Prepare a figure on which several lines can be plotted incrementally.

     Args:
         xlabel, ylabel: Axis labels later forwarded to ``d2l.set_axes``.
         legend: Legend entries; ``None`` is replaced by an empty list.
         xlim, ylim: Axis limits later forwarded to ``d2l.set_axes``.
         xscale, yscale: Axis scale names later forwarded to ``d2l.set_axes``.
         fmts: Line format strings, kept on ``self.fmts``.
         nrows, ncols: Subplot grid shape passed to ``d2l.plt.subplots``.
         figsize: Figure size passed to ``d2l.plt.subplots``.
     """
     legend = [] if legend is None else legend
     d2l.use_svg_display()
     fig, axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
     self.fig = fig
     # A single-cell grid gives back one Axes object rather than a
     # container; wrap it so ``self.axes[0]`` works either way.
     self.axes = [axes] if nrows * ncols == 1 else axes

     def _apply_axes_config():
         # Closure over the constructor arguments; ``self.axes`` is read at
         # call time, not at construction time.
         return d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim,
                             xscale, yscale, legend)

     self.config_axes = _apply_axes_config
     self.X = None
     self.Y = None
     self.fmts = fmts
Ejemplo n.º 2
0
def show_heatmaps(matrices,
                  xlabel,
                  ylabel,
                  titles=None,
                  figsize=(2.5, 2.5),
                  cmap='Reds'):
    """Draw a grid of heatmaps that share a single colorbar, then show it.

    Args:
        matrices: Tensor-like object whose ``shape[0]`` and ``shape[1]`` give
            the number of grid rows and columns; each ``matrices[i][j]`` is
            drawn with ``imshow`` (presumably a 4-D torch tensor of shape
            ``(rows, cols, height, width)`` -- confirm with callers).
        xlabel: x-axis label, applied to the bottom row of subplots only.
        ylabel: y-axis label, applied to the first column of subplots only.
        titles: Optional sequence of titles indexed by column position.
        figsize: Figure size passed to ``d2l.plt.subplots``.
        cmap: Colormap name passed to ``imshow``.
    """
    d2l.use_svg_display()
    num_rows, num_cols = matrices.shape[0], matrices.shape[1]
    # squeeze=False keeps ``axes`` two-dimensional even for 1xN / Nx1 / 1x1
    # grids, so the nested loop below always sees rows of Axes.
    fig, axes = d2l.plt.subplots(num_rows,
                                 num_cols,
                                 figsize=figsize,
                                 sharex=True,
                                 sharey=True,
                                 squeeze=False)
    for i, (row_axes, row_matrices) in enumerate(zip(axes, matrices)):
        for j, (ax, matrix) in enumerate(zip(row_axes, row_matrices)):
            # detach() drops autograd history; cpu() lets .numpy() work for
            # tensors that live on an accelerator (no-op for CPU tensors).
            pcm = ax.imshow(matrix.detach().cpu().numpy(), cmap=cmap)
            if i == num_rows - 1:
                ax.set_xlabel(xlabel)
            if j == 0:
                ax.set_ylabel(ylabel)
            if titles:
                ax.set_title(titles[j])
    # The shared colorbar is built from the last image drawn.
    fig.colorbar(pcm, ax=axes, shrink=0.6)
    # Go through the same pyplot module used above; a bare ``plt`` name is
    # not guaranteed to be in scope.
    d2l.plt.show()
Ejemplo n.º 3
0
 def __init__(self, xlabel=None, ylabel=None, legend=None, xlim=None,
              ylim=None, xscale='linear', yscale='linear',
              fmts=('-', 'm--', 'g-.', 'r:'), nrows=1, ncols=1,
              figsize=(3.5, 2.5)):
     """Prepare a figure on which several lines can be plotted incrementally.

     Args:
         xlabel, ylabel: Axis labels later forwarded to ``d2l.set_axes``.
         legend: Legend entries; ``None`` is replaced by an empty list.
         xlim, ylim: Axis limits later forwarded to ``d2l.set_axes``.
         xscale, yscale: Axis scale names later forwarded to ``d2l.set_axes``.
         fmts: Line format strings, kept on ``self.fmts``.
         nrows, ncols: Subplot grid shape passed to ``d2l.plt.subplots``.
         figsize: Figure size passed to ``d2l.plt.subplots``.
     """
     legend = [] if legend is None else legend
     d2l.use_svg_display()
     fig, axes = d2l.plt.subplots(nrows, ncols, figsize=figsize)
     self.fig = fig
     # A single-cell grid gives back one Axes object rather than a
     # container; wrap it so ``self.axes[0]`` works either way.
     self.axes = [axes] if nrows * ncols == 1 else axes

     def _apply_axes_config():
         # Closure over the constructor arguments; ``self.axes`` is read at
         # call time, not at construction time.
         return d2l.set_axes(self.axes[0], xlabel, ylabel, xlim, ylim,
                             xscale, yscale, legend)

     self.config_axes = _apply_axes_config
     self.X = None
     self.Y = None
     self.fmts = fmts
Ejemplo n.º 4
0







%matplotlib inline
from d2l import torch as d2l
import torch
import torchvision
from torchvision import transforms
from torch.utils import data

d2l.use_svg_display()

# ``ToTensor`` turns each PIL image into a 32-bit floating point tensor,
# dividing the raw pixel values by 255 so every value lands in [0, 1].
trans = transforms.ToTensor()

# Build the training split first and then the test split, both rooted at
# "../data" with the same transform and with downloading enabled.
mnist_train, mnist_test = (
    torchvision.datasets.FashionMNIST(
        root="../data", train=use_train_split, transform=trans, download=True)
    for use_train_split in (True, False))