def plot_new_samples(run_id, model_dir, trained, *, dump_root="/home/akajal/WatChMaL/VAE/dumps/"):
    """Plot two randomly chosen generated samples and their charge histograms.

    Loads ``<dump_root>/<run_id>/samples/<model_dir>_<trained|untrained>.npz``,
    which must contain the arrays ``samples``, ``predicted_labels`` and
    ``predicted_energies``. The samples are reshaped so that their two leading
    axes are merged, and labels/energies are reshaped into column vectors, so
    that row ``k`` of each array refers to the same sample.

    Args:
        run_id: Name of the run directory under ``dump_root``.
        model_dir: File-name prefix of the sample archive.
        trained: Truthy to load the archive of the trained model, falsy for the
            untrained one.
        dump_root: Root directory that holds the per-run dump folders.

    Raises:
        ValueError: If the archive contains no samples.
    """
    import os

    # Plain truthiness so numpy bools / ints select the right file too.
    model_status = "trained" if trained else "untrained"
    np_arr_path = os.path.join(dump_root, run_id, "samples", f"{model_dir}_{model_status}.npz")

    # An NpzFile keeps the archive open; the context manager closes it once
    # the needed arrays have been read into memory.
    with np.load(np_arr_path) as np_arr:
        np_samples = np_arr["samples"]
        np_labels = np_arr["predicted_labels"]
        np_energies = np_arr["predicted_energies"]

    # Assumes ``samples`` is 5-D: (a, b, c, d, e) -> (a*b, c, d, e).
    np_samples = np_samples.reshape(
        -1, np_samples.shape[2], np_samples.shape[3], np_samples.shape[4]
    )
    np_labels = np_labels.reshape(-1, 1)
    np_energies = np_energies.reshape(-1, 1)

    num_samples = np_labels.shape[0]
    if num_samples == 0:
        raise ValueError(f"No samples found in {np_arr_path}")
    i, j = random.randint(0, num_samples - 1), random.randint(0, num_samples - 1)

    plot_utils.plot_actual_vs_recon(
        np_samples[i],
        np_samples[j],
        label_dict[np_labels[i].item()],
        np_energies[i].item(),
        label_dict[np_labels[j].item()],
        np_energies[j].item(),
        show_plot=True,
    )
    plot_utils.plot_charge_hist(np_samples[i], np_samples[j], 0, num_bins=200)
    # NOTE(review): the full sample set is passed as both arguments, presumably
    # to histogram the whole set at once; confirm this is intended.
    plot_utils.plot_charge_hist(np_samples, np_samples, 0, num_bins=200)
def plot_old_events(run_id, iteration, mode, *, dump_root="/home/akajal/WatChMaL/VAE/dumps/"):
    """Plot a random event against its reconstruction from a saved dump.

    Loads ``<dump_root>/<run_id>/val_iteration_<iteration>.npz`` when ``mode``
    equals ``"validation"``, otherwise ``<dump_root>/<run_id>/iteration_<iteration>.npz``.
    The archive must contain ``events``, ``recon``, ``labels`` and ``energies``.

    Args:
        run_id: Name of the run directory under ``dump_root``.
        iteration: Iteration number embedded in the file name.
        mode: ``"validation"`` selects the validation dump; any other value
            selects the training dump.
        dump_root: Root directory that holds the per-run dump folders.

    Raises:
        ValueError: If the archive contains no events.
    """
    import os

    # Compare by value: ``mode is "validation"`` tests object identity, which
    # depends on string interning and emits a SyntaxWarning on Python 3.8+.
    prefix = "val_iteration_" if mode == "validation" else "iteration_"
    np_arr_path = os.path.join(dump_root, run_id, prefix + str(iteration) + ".npz")

    # An NpzFile keeps the archive open; the context manager closes it once
    # the needed arrays have been read into memory.
    with np.load(np_arr_path) as np_arr:
        np_event = np_arr["events"]
        np_recon = np_arr["recon"]
        np_labels = np_arr["labels"]
        np_energies = np_arr["energies"]

    num_events = np_labels.shape[0]
    if num_events == 0:
        raise ValueError(f"No events found in {np_arr_path}")
    i = random.randint(0, num_events - 1)

    # ``.item()`` yields a hashable Python scalar even if a label row is a
    # size-1 array, matching how the other plotting helpers index label_dict.
    label = label_dict[np_labels[i].item()]
    energy = np_energies[i].item()
    plot_utils.plot_actual_vs_recon(
        np_event[i], np_recon[i], label, energy, label, energy, show_plot=True
    )

    # Move axis 1 to the end (assumes 4-D events; presumably channels-first ->
    # channels-last to match ``recon``, confirm). np.transpose returns the same
    # strided view the previous torch permute produced, without the copy.
    plot_utils.plot_charge_hist(
        np.transpose(np_event, (0, 2, 3, 1)), np_recon, iteration, num_bins=200
    )
def plot_samples(run_id, model_dir, trained, *, dump_root="/home/akajal/WatChMaL/VAE/dumps/"):
    """Plot two randomly chosen records from a pickled sample dump.

    Loads ``<dump_root>/<run_id>/samples/<model_dir>_<trained|untrained>_samples.npy``
    with pickling enabled. Each record ``r`` is indexed as ``r[0][0]``
    (plotted sample), ``r[1]`` (label key) and ``r[2][0]`` (energy);
    presumably a (sample, label, energy) tuple, confirm against the writer.

    Args:
        run_id: Name of the run directory under ``dump_root``.
        model_dir: File-name prefix of the sample dump.
        trained: Truthy to load the dump of the trained model, falsy for the
            untrained one.
        dump_root: Root directory that holds the per-run dump folders.

    Raises:
        ValueError: If the dump contains no records.
    """
    import os

    # Plain truthiness so numpy bools / ints select the right file too.
    model_status = "trained" if trained else "untrained"
    np_arr_path = os.path.join(
        dump_root, run_id, "samples", f"{model_dir}_{model_status}_samples.npy"
    )
    # NOTE(review): allow_pickle=True can execute code embedded in the file;
    # only load dumps produced by this project, never untrusted files.
    np_arr = np.load(np_arr_path, allow_pickle=True)

    num_records = np_arr.shape[0]
    if num_records == 0:
        raise ValueError(f"No samples found in {np_arr_path}")
    i, j = random.randint(0, num_records - 1), random.randint(0, num_records - 1)

    # NOTE(review): unlike the sibling helpers, only the label/energy of
    # record i are passed; confirm plot_actual_vs_recon accepts this form.
    plot_utils.plot_actual_vs_recon(
        np_arr[i][0][0],
        np_arr[j][0][0],
        label_dict[np_arr[i][1].item()],
        np_arr[i][2][0],
        show_plot=True,
    )
    plot_utils.plot_charge_hist(np_arr[i][0][0], np_arr[j][0][0], 0, num_bins=200)