import numpy as np
import torch
from sklearn.model_selection import KFold

import aitac       # project module: ConvNet, pearson_loss, train/test routines
import plot_utils  # project module: plotting helpers

# batch_size, num_classes, num_filters, learning_rate, num_epochs and device
# are module-level settings defined elsewhere in this script.


def cross_validate(x, y, peak_names, output_file_path):
    kf = KFold(n_splits=10, shuffle=True)

    pred_all = []
    corr_all = []
    peak_order = []
    for train_index, test_index in kf.split(x):
        train_data, eval_data = x[train_index, :, :], x[test_index, :, :]
        train_labels, eval_labels = y[train_index, :], y[test_index, :]
        train_names, eval_name = peak_names[train_index], peak_names[test_index]

        # Data loaders
        train_dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(train_data), torch.from_numpy(train_labels))
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset, batch_size=batch_size, shuffle=False)

        eval_dataset = torch.utils.data.TensorDataset(
            torch.from_numpy(eval_data), torch.from_numpy(eval_labels))
        eval_loader = torch.utils.data.DataLoader(
            dataset=eval_dataset, batch_size=batch_size, shuffle=False)

        # Create model
        model = aitac.ConvNet(num_classes, num_filters).to(device)

        # Loss and optimizer
        criterion = aitac.pearson_loss
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

        # Train model
        model, best_loss = aitac.train_model(
            train_loader, eval_loader, model, device, criterion, optimizer,
            num_epochs, output_file_path)

        # Predict on the held-out fold
        predictions, max_activations, max_act_index = aitac.test_model(
            eval_loader, model, device)

        # Plot the correlation histogram for this fold
        correlations = plot_utils.plot_cors(eval_labels, predictions,
                                            output_file_path)

        pred_all.append(predictions)
        corr_all.append(correlations)
        peak_order.append(eval_name)

    pred_all = np.vstack(pred_all)
    corr_all = np.hstack(corr_all)
    peak_order = np.hstack(peak_order)

    return pred_all, corr_all, peak_order
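# ---------------------------------------------------------------------------
# Usage sketch (not part of the original script): how cross_validate might be
# driven. The file names, array shapes, and hyperparameter values below are
# hypothetical assumptions; cross_validate reads batch_size, num_classes,
# num_filters, learning_rate, num_epochs and device from module scope, so the
# driver must live in the same module.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size = 100          # hypothetical value
    num_classes = 81          # hypothetical: e.g. one output per cell type
    num_filters = 300         # hypothetical value
    learning_rate = 0.001     # hypothetical value
    num_epochs = 10           # hypothetical value
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Hypothetical inputs: (N, 4, L) one-hot sequences, (N, num_classes)
    # accessibility targets, and N peak names.
    x = np.load('one_hot_seqs.npy').astype(np.float32)
    y = np.load('cell_type_array.npy').astype(np.float32)
    peak_names = np.load('peak_names.npy', allow_pickle=True)

    pred_all, corr_all, peak_order = cross_validate(x, y, peak_names, 'outputs/')
    print('mean per-peak correlation: %.3f' % corr_all.mean())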
# Data loaders
train_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(train_data), torch.from_numpy(train_labels))
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=batch_size, shuffle=False)

valid_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(valid_data), torch.from_numpy(valid_labels))
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_dataset, batch_size=batch_size, shuffle=False)

eval_dataset = torch.utils.data.TensorDataset(
    torch.from_numpy(eval_data), torch.from_numpy(eval_labels))
eval_loader = torch.utils.data.DataLoader(
    dataset=eval_dataset, batch_size=batch_size, shuffle=False)

# Create model
model = aitac.ConvNet(num_classes, num_filters).to(device)

# Weights from the original model with 300 filters
checkpoint = torch.load("../models/" + original_model + ".ckpt")

# Indices of filters in the original model (cast to int so they can be used
# as an index array)
filters = np.loadtxt('../data/filter_set99_index.txt').astype(int)

# State dict of the new model, to be filled with the transferred weights
checkpoint2 = model.state_dict()

# Copy original model weights into the new model
for i, (layer_name, layer_weights) in enumerate(checkpoint.items()):
    # For all first layer weights take the subset of retained filters
    if i < 2:
        # Assumed completion -- the original fragment ends at `if i < 2:`;
        # the first two state-dict entries are the first conv layer's weight
        # and bias, both with the filter dimension first.
        checkpoint2[layer_name] = layer_weights[filters, ...]
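# ---------------------------------------------------------------------------
# Sketch of a plausible continuation (an assumption -- the original script is
# cut off above): load the merged state dict into the smaller model, then
# train it with the same loss/optimizer setup used in cross_validate.
# output_file_path is assumed to be defined as in the function above.
# ---------------------------------------------------------------------------
model.load_state_dict(checkpoint2)

criterion = aitac.pearson_loss
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
model, best_loss = aitac.train_model(train_loader, valid_loader, model, device,
                                     criterion, optimizer, num_epochs,
                                     output_file_path)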