def train(cnn, epochs=80, learn_rate=0.001, batch_size=100, gpu=True):
    """
    Train a regression CNN.

    Note that you do not need this function. Included for reference.

    The network is optimised with Adam against an L2 (MSE) loss, using
    the grey images as inputs and the RGB images as targets. After every
    epoch the loss of the last training batch and the mean loss over the
    test batches are printed. Once all epochs are done, the model's
    ``state_dict`` is written to ``regression_cnn_k<kernel>_f<filters>.pkl``.

    Args:
        cnn: the model to train; it is moved to the GPU when ``gpu`` is true.
        epochs: number of passes over the training data.
        learn_rate: learning rate handed to the Adam optimiser.
        batch_size: batch size handed to ``get_batch``.
        gpu: if true, call ``cnn.cuda()`` and forward the flag to
            ``get_torch_vars``.

    Note:
        Reads the module-level ``args`` namespace (``kernel`` and the
        filter count) to build the checkpoint file name.
    """
    if gpu:
        cnn.cuda()

    # Set up L2 loss
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(cnn.parameters(), lr=learn_rate)

    # Loading & transforming data
    (x_train, y_train), (x_test, y_test) = load_cifar10()
    train_rgb, train_grey = process(x_train, y_train)
    test_rgb, test_grey = process(x_test, y_test)

    print("Beginning training ...")
    for epoch in range(epochs):
        # Train the Model
        cnn.train()  # Change model to 'train' mode
        loss = None
        for xs, ys in get_batch(train_grey, train_rgb, batch_size):
            images, labels = get_torch_vars(xs, ys, gpu)
            # Forward + Backward + Optimize
            optimizer.zero_grad()
            outputs = cnn(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # `.item()` replaces the old `loss.data[0]`, which fails on 0-dim
        # tensors in current PyTorch. It is read once per epoch so the
        # training loop is not forced to synchronise on every batch.
        last_loss = loss.item() if loss is not None else float("nan")
        print('Epoch [%d/%d], Loss: %.4f' % (epoch + 1, epochs, last_loss))

        # Evaluate the model
        cnn.eval()  # Change model to 'eval' mode (BN uses moving mean/var).
        losses = []
        # No gradients are needed for validation; skipping graph
        # construction saves memory without changing the loss values.
        with torch.no_grad():
            for xs, ys in get_batch(test_grey, test_rgb, batch_size):
                images, labels = get_torch_vars(xs, ys, gpu)
                outputs = cnn(images)
                losses.append(criterion(outputs, labels).item())
        val_loss = np.mean(losses) if losses else float("nan")
        print('Epoch [%d/%d], Val Loss: %.4f' % (epoch + 1, epochs, val_loss))

    # Save the Trained Model
    # The command-line parser may expose the filter count as either
    # `num_filters` or `filters`; accept both so the save cannot fail
    # with an AttributeError after a full training run.
    num_filters = getattr(args, "num_filters", None)
    if num_filters is None:
        num_filters = args.filters
    torch.save(
        cnn.state_dict(),
        'regression_cnn_k%d_f%d.pkl' % (args.kernel, num_filters),
    )
parser.add_argument(
    '-f', '--filters',
    default=32,
    type=int,
    help="Base number of convolution filters",
)
parser.add_argument(
    '-c', '--colors',
    default='colors/color_kmeans24_cat7.npy',
    help="Discrete color clusters to use",
)
args = parser.parse_args()

# Load the colour categories.
# NOTE: allow_pickle=True unpickles objects stored in the file, so only
# point --colors at files you trust.
color_file = np.load(args.colors, encoding='latin1', allow_pickle=True)
colors = color_file[0]
num_colors = np.shape(colors)[0]

# Load the data first for consistency: the RNG is seeded before anything
# else touches the dataset.
print("Loading data...")
npr.seed(0)
(x_train, y_train), (x_test, y_test) = load_cifar10()
test_rgb, test_grey = process(x_test, y_test)
test_rgb_cat = get_rgb_cat(test_rgb, colors)

# Pick the architecture named on the command line. Any value other than
# "CNN" or "UNet" falls through to the dilated U-Net.
if args.model == "CNN":
    model_cls = CNN
elif args.model == "UNet":
    model_cls = UNet
else:  # model == "DUNet"
    model_cls = DilatedUNet
cnn = model_cls(args.kernel, args.filters, num_colors)

# Restore the trained weights; the map_location callback returns each
# storage unchanged instead of moving it to the device it was saved from.
print("Loading checkpoint...")
checkpoint = torch.load(
    args.checkpoint,
    map_location=lambda storage, loc: storage,
)
cnn.load_state_dict(checkpoint)

# Take the index of the test image
id = args.index