# Example #1
0
 trainloader, testloader = dataloader.get_dataloader(dataset, batch_size)
 # We only handle square images for now. The image shape can be read from
 #   the shape of the elements in the dataloader. Each batch follows the
 #   PyTorch NCHW convention:
 #   (num_images, num_channels, height, width)
 #   num_images is capped at batch_size. num_channels is the number of color
 #   channels in the image: usually 1 (gray-scale) or 3 (colored).
 #   Since we expect the width and height to be the same, we read the third
 #   shape value (height == width) as our image size.
 first_batch = next(iter(trainloader))[0]
 img_size = first_batch.shape[2]
 num_colors = first_batch.shape[1]
 net = ConvNet(code_size, img_size, num_colors, device).to(device)
 # NOTE(review): `optimizer` is not referenced below — presumably train()
 #   picks it up from the enclosing scope; verify before removing.
 optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5)
 t_start = time.time()
 t_last = time.time()
 for epoch in range(epochs):
     train(trainloader, device)
     test(testloader, device)
     # Report total elapsed time and per-epoch time.
     print(
         f'{time.time()-t_start:.1f}s\t{time.time()-t_last:.1f}s\tDone running epoch {epoch+1}'
     )
     t_last = time.time()
 # Save the learned weights for later use
 torch.save(net, f'static/{dataset_name}_state.pth')
 # We want to sample some values sent to the decoder. The reason is that
 #   we want to use this to define a range for each of the n nodes in the
 #   code that we can use in the front end. We do this by sending a batch
 #   through the encoder.
 # Fix: the batch must be moved to the same device as the network, and the
 #   result must be brought back to the CPU before .numpy() — calling
 #   .numpy() on a CUDA tensor raises a TypeError. no_grad() skips building
 #   an autograd graph we would never use.
 with torch.no_grad():
     code_input = net.encoder(first_batch.to(device)).detach().cpu().numpy()
 np.save(f'static/{dataset_name}_code', code_input)