def main():
    """Train one survival RNN per (fold, swap, trial) and save each checkpoint.

    Loads the preprocessed arrays for ``hp.gender``, then for every cross-
    validation fold and each half ("swap") of that fold builds case/control
    tensors, trains ``hp.num_trials`` independent networks with ``CoxPHLoss``,
    and writes each trained ``state_dict`` under ``hp.log_dir``.

    Side effects: reads ``.npz``/``.feather``/``.pkl`` inputs from
    ``hp.data_pp_dir`` and writes ``.pt`` checkpoints under ``hp.log_dir``.
    """
    hp = Hyperparameters()

    print('Load data...')
    data = np.load(hp.data_pp_dir + 'data_arrays_' + hp.gender + '.npz')
    df_index_code = feather.read_dataframe(hp.data_pp_dir + 'df_index_code_' + hp.gender + '.feather')

    # Hoisted out of the fold/swap loop: the pickle never changes between
    # iterations, so re-reading it every pass was wasted I/O.
    if not hp.redundant_predictors:
        cols_list = load_obj(hp.data_pp_dir + 'cols_list.pkl')

    # Checkpoint directory suffix mirrors the predictor-reduction setting.
    dir_suffix = '' if hp.redundant_predictors else '_no_redundancies'

    print('Train on each fold...')
    for fold in range(hp.num_folds):
        for swap in range(2):
            print('Fold: {} Swap: {}'.format(fold, swap))

            # Select the half of this fold assigned to `swap`.
            idx = (data['fold'][:, fold] == swap)
            x = data['x'][idx]
            time = data['time'][idx]
            event = data['event'][idx]
            codes = data['codes'][idx]
            month = data['month'][idx]
            diagt = data['diagt'][idx]

            if not hp.redundant_predictors:
                # Keep only the reduced predictor columns, in the order
                # given by hp.reduced_col_list.
                x = x[:, [cols_list.index(i) for i in hp.reduced_col_list]]

            # Sort for the Cox partial-likelihood computation and locate
            # cases and their risk-set boundaries.
            sort_idx, case_idx, max_idx_control = sort_and_case_indices(x, time, event)
            x, time, event = x[sort_idx], time[sort_idx], event[sort_idx]
            codes, month, diagt = codes[sort_idx], month[sort_idx], diagt[sort_idx]

            print('Create data loaders and tensors...')
            case = utils.TensorDataset(torch.from_numpy(x[case_idx]),
                                       torch.from_numpy(time[case_idx]),
                                       torch.from_numpy(max_idx_control),
                                       torch.from_numpy(codes[case_idx]),
                                       torch.from_numpy(month[case_idx]),
                                       torch.from_numpy(diagt[case_idx]))

            x = torch.from_numpy(x)
            time = torch.from_numpy(time)
            event = torch.from_numpy(event)
            codes = torch.from_numpy(codes)
            month = torch.from_numpy(month)
            diagt = torch.from_numpy(diagt)

            for trial in range(hp.num_trials):
                print('Trial: {}'.format(trial))

                # Create batch queues
                trn_loader = utils.DataLoader(case, batch_size=hp.batch_size, shuffle=True, drop_last=True)

                print('Train...')
                # Neural Net: timestamped name keeps repeated trials distinct.
                hp.model_name = str(trial) + '_' + datetime.now().strftime('%Y%m%d_%H%M%S_%f') + '.pt'
                num_input = x.shape[1] + 1 if hp.nonprop_hazards else x.shape[1]
                net = NetRNNFinal(num_input, df_index_code.shape[0] + 1, hp).to(hp.device)  # +1 for zero padding
                criterion = CoxPHLoss().to(hp.device)
                optimizer = optim.Adam(net.parameters(), lr=hp.learning_rate)

                for epoch in range(hp.max_epochs):
                    trn(trn_loader, x, codes, month, diagt, net, criterion, optimizer, hp)

                # Single save path: the redundant/non-redundant variants only
                # differed by the directory suffix computed above.
                torch.save(net.state_dict(),
                           hp.log_dir + 'fold_' + str(fold) + '_' + str(swap) + dir_suffix + '/' + hp.model_name)
    print('Done')
def objective(trial, data, df_index_code):
    """Optuna objective: train with trial-sampled hyperparameters.

    Splits ``data`` into training (``fold != 99``) and validation
    (``fold == 99``) sets, trains a ``NetRNNFinal`` with early stopping on
    validation ``CoxPHLoss``, checkpoints the best model, and reports the
    running best loss to Optuna for pruning.

    Args:
        trial: ``optuna.Trial`` used both to sample hyperparameters (via
            ``Hyperparameters(trial)``) and for report/prune callbacks.
        data: mapping of numpy arrays keyed by 'fold', 'x', 'time', 'event',
            'codes', 'month', 'diagt'.
        df_index_code: dataframe whose row count (+1 for zero padding) sizes
            the code-embedding input.

    Returns:
        float: the best (lowest) validation loss observed.

    Raises:
        optuna.TrialPruned: when Optuna decides to prune this trial.
    """
    hp = Hyperparameters(trial)
    print(trial.params)

    # Fold 99 is reserved as the validation split; everything else trains.
    idx_trn = (data['fold'] != 99)
    x_trn = data['x'][idx_trn]
    time_trn = data['time'][idx_trn]
    event_trn = data['event'][idx_trn]
    codes_trn = data['codes'][idx_trn]
    month_trn = data['month'][idx_trn]
    diagt_trn = data['diagt'][idx_trn]

    idx_val = (data['fold'] == 99)
    x_val = data['x'][idx_val]
    time_val = data['time'][idx_val]
    event_val = data['event'][idx_val]
    codes_val = data['codes'][idx_val]
    month_val = data['month'][idx_val]
    diagt_val = data['diagt'][idx_val]

    # could move this outside objective function for efficiency
    sort_idx_trn, case_idx_trn, max_idx_control_trn = sort_and_case_indices(x_trn, time_trn, event_trn)
    sort_idx_val, case_idx_val, max_idx_control_val = sort_and_case_indices(x_val, time_val, event_val)

    x_trn, time_trn, event_trn = x_trn[sort_idx_trn], time_trn[sort_idx_trn], event_trn[sort_idx_trn]
    codes_trn, month_trn, diagt_trn = codes_trn[sort_idx_trn], month_trn[sort_idx_trn], diagt_trn[sort_idx_trn]
    x_val, time_val, event_val = x_val[sort_idx_val], time_val[sort_idx_val], event_val[sort_idx_val]
    codes_val, month_val, diagt_val = codes_val[sort_idx_val], month_val[sort_idx_val], diagt_val[sort_idx_val]

    print('Create data loaders and tensors...')
    case_trn = utils.TensorDataset(torch.from_numpy(x_trn[case_idx_trn]),
                                   torch.from_numpy(time_trn[case_idx_trn]),
                                   torch.from_numpy(max_idx_control_trn),
                                   torch.from_numpy(codes_trn[case_idx_trn]),
                                   torch.from_numpy(month_trn[case_idx_trn]),
                                   torch.from_numpy(diagt_trn[case_idx_trn]))
    case_val = utils.TensorDataset(torch.from_numpy(x_val[case_idx_val]),
                                   torch.from_numpy(time_val[case_idx_val]),
                                   torch.from_numpy(max_idx_control_val),
                                   torch.from_numpy(codes_val[case_idx_val]),
                                   torch.from_numpy(month_val[case_idx_val]),
                                   torch.from_numpy(diagt_val[case_idx_val]))

    x_trn, x_val = torch.from_numpy(x_trn), torch.from_numpy(x_val)
    time_trn, time_val = torch.from_numpy(time_trn), torch.from_numpy(time_val)
    event_trn, event_val = torch.from_numpy(event_trn), torch.from_numpy(event_val)
    codes_trn, codes_val = torch.from_numpy(codes_trn), torch.from_numpy(codes_val)
    month_trn, month_val = torch.from_numpy(month_trn), torch.from_numpy(month_val)
    diagt_trn, diagt_val = torch.from_numpy(diagt_trn), torch.from_numpy(diagt_val)

    # Create batch queues
    trn_loader = utils.DataLoader(case_trn, batch_size=hp.batch_size, shuffle=True, drop_last=True)
    val_loader = utils.DataLoader(case_val, batch_size=hp.batch_size, shuffle=False, drop_last=False)

    print('Train...')
    # Neural Net
    hp.model_name = str(trial.number) + '_' + hp.model_name
    num_input = x_trn.shape[1] + 1 if hp.nonprop_hazards else x_trn.shape[1]
    net = NetRNNFinal(num_input, df_index_code.shape[0] + 1, hp).to(hp.device)  # +1 for zero padding
    criterion = CoxPHLoss().to(hp.device)
    optimizer = optim.Adam(net.parameters(), lr=hp.learning_rate)

    # FIX: was initialized to the magic value 100., so if validation loss
    # never dropped below 100 no checkpoint was ever written even though
    # 100.0 was still reported/returned as the "best" loss. With +inf the
    # first epoch always saves a model and the return value is a real
    # observed loss.
    best, num_bad_epochs = float('inf'), 0
    for epoch in range(1000):
        trn(trn_loader, x_trn, codes_trn, month_trn, diagt_trn, net, criterion, optimizer, hp)
        loss_val = val(val_loader, x_val, codes_val, month_val, diagt_val, net, criterion, epoch, hp)

        # early stopping
        if loss_val < best:
            print('############### Saving good model ###############################')
            torch.save(net.state_dict(), hp.log_dir + hp.model_name)
            best = loss_val
            num_bad_epochs = 0
        else:
            num_bad_epochs += 1
            if num_bad_epochs == hp.patience:
                break

        # pruning
        trial.report(best, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()

    print('Done')
    return best