# Build the train/test loaders for the requested dataset.
trainloader, testloader = dataloader.get_dataloader(dataset, batch_size)

# We only handle square images for now. The image shape can be read from
# the shape of the elements in the dataloader. Each batch has the format:
#     (num_images, num_channels, width, height)
# num_images is capped at batch_size. num_channels is the number of color
# channels in the image: usually 1 (gray-scale) or 3 (colored). Since we
# expect width and height to be equal, we read the third shape value
# (width) as our image size.
first_batch = next(iter(trainloader))[0]
img_size = first_batch.shape[2]
num_colors = first_batch.shape[1]

net = ConvNet(code_size, img_size, num_colors, device).to(device)
optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-5)

# Per-epoch training loop; t_start/t_last give total and per-epoch timings.
t_start = time.time()
t_last = time.time()
for epoch in range(epochs):
    train(trainloader, device)
    test(testloader, device)
    print(
        f'{time.time()-t_start:.1f}s\t{time.time()-t_last:.1f}s\tDone running epoch {epoch+1}'
    )
    t_last = time.time()

# Save the learned weights for later use.
torch.save(net, f'static/{dataset_name}_state.pth')

# We want to sample some values sent to the decoder. The reason is that
# we want to use this to define a range for each of the n nodes in the
# code that we can use in the front end. We do this by sending a batch
# through the encoder.
# FIX: the batch must be on the same device as the network, no gradients
# are needed for this forward pass, and .numpy() requires a CPU tensor —
# the original `net.encoder(first_batch).detach().numpy()` raised when
# `device` was a GPU.
with torch.no_grad():
    code_input = net.encoder(first_batch.to(device)).cpu().numpy()
np.save(f'static/{dataset_name}_code', code_input)