def _templates_cache_path(file_name, template_opt, cam_target, num_patients, random, noise_inter):
    """Return the pickle path under ``plot_results`` that caches one ``get_templates`` result.

    A truthy ``random`` adds a ``random_`` prefix. Whenever ``random`` is not None
    (so for both True and False), the file lives in the ``plot_results/noise``
    sub-directory and its name carries ``noise_inter``. Only ``random=None`` uses
    the flat ``plot_results/<name>`` path.
    """
    random_str = 'random_' if random else ''
    save_name = f'{random_str}{template_opt}_template_class{cam_target}_{file_name}_{num_patients}patients'
    if random is not None:
        save_name = os.path.join('noise', f'{save_name}_figure_data_noise={noise_inter}')
    return os.path.join('plot_results', save_name)


def get_templates(file_name, oversample, template_opt='spindle', cam_target=2, num_patients=40, num_ch=2,
                  before=5., after=5., task='all', try2load=True, gap_norm_opt='batch_norm', cuda_id=0,
                  random=False, normalize_signals=False, noise_inter=0):
    """Collect CAM-aligned templates over the test split of a saved model, with on-disk caching.

    The model stored at ``saved_models/<file_name>/<file_name>_params.pkl`` is run
    over the test loader. For each sample, templates are extracted around the
    class-activation map (CAM) by the helper selected with ``template_opt``.

    Args:
        file_name: Saved-model name. Substrings also configure the run (case-insensitive
            unless noted):
            - 'freq' adds the 'frequency' feature.
            - 'num_spindles' adds the 'num_spindles' feature; otherwise 'spindle'
              adds the 'spindle' feature.
            - 'ds' selects fs=80, otherwise fs=125.
            - 'low' (case-sensitive) sets ``low_sp`` for ``init_datasets``.
        oversample: Forwarded to ``init_datasets``.
        template_opt: One of 'spindle', 'activation', 'sw', 'rem', 'emg'. Selects the
            template extractor.
        cam_target: If an int, that CAM row is used for every sample and it is also passed
            as ``filter_stage`` to ``init_datasets``. Otherwise each sample's own label
            selects the CAM row.
        num_patients: Forwarded to ``init_datasets``; also part of the cache file name.
        num_ch: Number of input channels for the model and the data loader.
        before: Seconds before each event, passed as ``sec_before`` to the extractors.
        after: Seconds after each event, passed as ``sec_after`` to the extractors.
        task: Forwarded to ``run_params``/``init_datasets``. A task containing 'rem'
            gives a 2-class model, otherwise 5 classes.
        try2load: If True and the cache file exists, return the cached result.
        gap_norm_opt: Forwarded to ``HSICClassifier``.
        cuda_id: Forwarded to ``get_device``.
        random: Forwarded to the extractors. It also affects the cache path
            (see ``_templates_cache_path``).
        normalize_signals: Forwarded to ``init_datasets``.
        noise_inter: Forwarded to the extractors (except 'emg' and 'activation');
            part of the cache file name.

    Returns:
        ``np.vstack`` of all per-sample templates, transposed. This is presumably
        one template per column; confirm against the extractor helpers.

    Raises:
        ValueError: If ``template_opt`` is not a supported option.
    """
    device = get_device(cuda_id)
    lower_name = file_name.lower()

    features_subset = []
    if 'freq' in lower_name:
        features_subset.append('frequency')
    if 'num_spindles' in lower_name:
        features_subset.append('num_spindles')
    elif 'spindle' in lower_name:
        features_subset.append('spindle')

    # Explicit check instead of `assert`: an assert is stripped under -O, and an unknown
    # option would then leave `templates` unbound inside the loop below.
    if template_opt not in ('spindle', 'activation', 'sw', 'rem', 'emg'):
        raise ValueError(f'Unsupported template_opt: {template_opt!r}')

    low_sp = 'low' in file_name  # NOTE: case-sensitive, unlike the substring checks above
    feature_opt, signal_len, one_slice, dataset_dir = run_params(
        file_name, features_subset, def_feature_opt='HSIC+Concat', task=task)
    fs = 80 if 'ds' in lower_name else 125
    num_classes = 2 if 'rem' in task.lower() else 5
    print(f'Sample Frequency: {fs}')
    print(f'Num Classes: {num_classes}')

    # One path for both reading and writing the cache. Previously the result was written
    # to plot_results/noise/... but looked up in plot_results/..., so the cache never hit.
    file_path = _templates_cache_path(file_name, template_opt, cam_target, num_patients, random, noise_inter)
    if try2load and os.path.exists(file_path):
        # NOTE: pickle.load can execute arbitrary code; only load caches written by this function.
        with open(file_path, 'rb') as f:
            return pkl.load(f)

    filter_stage = cam_target if isinstance(cam_target, int) else None
    test_loader = init_datasets(num_patients=num_patients, dataset_dir=dataset_dir, batch_size=1, conv_d=1,
                                features_subset=features_subset, one_slice=one_slice, modes=['test'],
                                random_flag=False, num_ch=num_ch, low_sp=low_sp, filter_stage=filter_stage,
                                task=task, normalize_signals=normalize_signals, oversample=oversample)

    file_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'saved_models', file_name)
    model = HSICClassifier(num_classes=num_classes, signal_len=signal_len, feature_opt=feature_opt,
                           in_channels=num_ch, feature_len=test_loader.dataset.feature_len,
                           gap_norm_opt=gap_norm_opt)
    model.load_state_dict(torch.load(os.path.join(file_dir, f'{file_name}_params.pkl'), map_location='cpu'))
    model.to(device)
    model.eval()

    all_templates = []
    with torch.no_grad():
        for signal, label, signal_name, features in tqdm(test_loader):
            signal, features = signal.to(device), features.to(device)
            _, cam, _ = model(signal, features)
            signal = test_loader.dataset.unnormalize_signal(signal)
            signal = signal.cpu().numpy().reshape(-1, signal.shape[-1])
            cam = cam.squeeze().cpu().numpy()
            # A fixed int target selects that class's CAM row; otherwise use the sample's label.
            cam = cam[cam_target, :] if isinstance(cam_target, int) else cam[label, :]
            signal = np.squeeze(signal)

            if template_opt == 'spindle':
                templates = get_spindle_templates(signal, fs=fs, cam=cam, random=random, sec_before=before,
                                                  sec_after=after, num_ch=num_ch, noise_inter=noise_inter)
            elif template_opt == 'activation':
                if num_ch == 2:
                    # TODO: only the first channel is used for two-channel input.
                    signal = signal[0, :]
                templates = get_activation_templates(cam, signal, sec_before=before, sec_after=after, fs=fs)
            elif template_opt == 'sw':
                templates = get_sw_templates(cam, signal, sec_before=before, sec_after=after, fs=fs,
                                             num_ch=num_ch, random=random, noise_inter=noise_inter)
            elif template_opt == 'rem':
                eog = test_loader.dataset.get_eog(signal_name)
                templates = get_rem_templates(cam=cam, signal_eog=eog, sec_before=before, sec_after=after,
                                              fs_eeg=fs, random=random, noise_inter=noise_inter)
            else:  # 'emg' (validated above)
                signal_emg = test_loader.dataset.get_emg(signal_name).squeeze()
                templates = get_emg_onset_templates(cam=cam, signal_emg=signal_emg, sec_before=before,
                                                    sec_after=after, fs_eeg=fs, random=random)
            all_templates += templates

    all_templates = np.vstack(all_templates).T

    # The noise sub-directory may not exist yet.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'wb') as f:
        pkl.dump(all_templates, f, protocol=pkl.HIGHEST_PROTOCOL)
    return all_templates
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
file_name = "1slice_DS_frequency_lambda20*0.25_lsb_remNrem_20_73"
task = 'rem_nrem'
cuda_id = 0
features_subset = ['frequency']
balanced_dataset = True

# ---------------------------------------------------------------------------
# Training hyper-parameters
# ---------------------------------------------------------------------------
batch_size = 32
num_epochs = 40
lr = 3e-4
rep_size = 512

# NOTE(review): this passes only `features_subset` and keeps a single return value, while
# `get_templates` calls `run_params(file_name, features_subset, def_feature_opt=..., task=...)`
# and unpacks four values. Confirm which `run_params` signature is in scope for this script.
feature_opt = run_params(features_subset)

# Seed before any dataset construction so sampling is reproducible.
torch.manual_seed(44)
device = get_device(cuda_id)

# Directory holding this experiment's saved model artifacts.
file_dir = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    'saved_models',
    file_name,
)

train_loader, val_loader, test_loader = init_datasets(
    task=task,
    balanced_dataset=balanced_dataset,
    batch_size=batch_size,
    normalize_signals=True,
    features_subset=features_subset,
)

# Two-class (rem vs. nrem), two-channel classifier.
main_model = HSICClassifier(
    num_classes=2,
    feature_opt=feature_opt,
    gap_norm_opt='batch_norm',
    feature_len=train_loader.dataset.feature_len,
    in_channels=2,
)
main_model = main_model.to(device)