def evaluate(epoch, data_loader, model, split, early_stopper, device, train_step=None):
    if split == 'val':
        epoch_metrics = utils.EpochWriter("Val", regression, experiment)
    else:
        epoch_metrics = utils.EpochWriter("Test", regression, experiment)
    model.eval()
    running_loss = 0.0
    running_loss_decomp = 0.0
    running_loss_los = 0.0
    running_loss_ihm = 0.0
    running_loss_pheno = 0.0
    running_loss_readmit = 0.0
    running_loss_ltm = 0.0
    tk = tqdm(data_loader, total=int(len(data_loader)))
    criterion = nn.BCEWithLogitsLoss()
    for i, data in enumerate(tk):
        if data is None:
            continue
        #------------------------ retrieve labels and masks per task -----------------#
        decomp_label, decomp_mask, los_label, los_mask, ihm_label, ihm_mask, \
            pheno_label, readmit_label, readmit_mask, ltm_label, ltm_mask, num_valid_data = retrieve_data(data, device)

        #------------------------ only load a modality if it is being used -----------#
        if use_ts:
            ts = torch.from_numpy(data['time series'])
            ts = ts.permute(1, 0, 2).float().to(device)
        else:
            ts = None
        if use_text:
            texts, texts_weight_mat = text_embedding(embedding_layer, data, device)
        else:
            texts = None
            texts_weight_mat = None
        if use_tab:
            tab_dict = data['tab']
            for cat in tab_dict:
                tab_dict[cat] = torch.from_numpy(tab_dict[cat]).long().to(device)
        else:
            tab_dict = None

        #------------------------ inference for all tasks ----------------------------#
        decomp_logits, los_logits, ihm_logits, pheno_logits, readmit_logits, ltm_logits = model(
            ts=ts, texts=texts, texts_weight_mat=texts_weight_mat, tab_dict=tab_dict)

        #------------------------ compute losses per task ----------------------------#
        loss_decomp = masked_weighted_cross_entropy_loss(None, decomp_logits, decomp_label, decomp_mask)
        loss_los = masked_weighted_cross_entropy_loss(los_class_weight, los_logits, los_label, los_mask)
        loss_ihm = masked_weighted_cross_entropy_loss(None, ihm_logits, ihm_label, ihm_mask)
        loss_pheno = criterion(pheno_logits, pheno_label)
        loss_readmit = masked_weighted_cross_entropy_loss(readmit_class_weight, readmit_logits, readmit_label, readmit_mask)
        loss_ltm = masked_weighted_cross_entropy_loss(None, ltm_logits, ltm_label, ltm_mask)

        losses = {
            'decomp': loss_decomp,
            'ihm': loss_ihm,
            'los': loss_los,
            'pheno': loss_pheno,
            'readmit': loss_readmit,
            'ltm': loss_ltm,
        }
        loss = 0.0
        for task in losses:
            loss += losses[task] * task_weight[task]

        #------------------------ track weighted per-task losses ---------------------#
        running_loss += losses[target_task].item() * task_weight[target_task]
        running_loss_decomp += loss_decomp.item() * task_weight['decomp']
        running_loss_los += loss_los.item() * task_weight['los']
        running_loss_ihm += loss_ihm.item() * task_weight['ihm']
        running_loss_pheno += loss_pheno.item() * task_weight['pheno']
        running_loss_readmit += loss_readmit.item() * task_weight['readmit']
        running_loss_ltm += loss_ltm.item() * task_weight['ltm']

        m = nn.Softmax(dim=1)
        sigmoid = nn.Sigmoid()
        decomp_pred = (sigmoid(decomp_logits)[:, 1]).cpu().detach().numpy()
        los_pred = m(los_logits).cpu().detach().numpy()
        ihm_pred = (sigmoid(ihm_logits)[:, 1]).cpu().detach().numpy()
        pheno_pred = sigmoid(pheno_logits).cpu().detach().numpy()
        readmit_pred = m(readmit_logits).cpu().detach().numpy()
        ltm_pred = (sigmoid(ltm_logits)[:, 1]).cpu().detach().numpy()
        #print(sigmoid(readmit_logits))

        outputs = {
            'decomp': {'pred': decomp_pred,
                       'label': decomp_label.cpu().detach().numpy(),
                       'mask': decomp_mask.cpu().detach().numpy()},
            'ihm': {'pred': ihm_pred,
                    'label': ihm_label.cpu().detach().numpy(),
                    'mask': ihm_mask.cpu().detach().numpy()},
            'los': {'pred': los_pred,
                    'label': los_label.cpu().detach().numpy(),
                    'mask': los_mask.cpu().detach().numpy()},
            'pheno': {'pred': pheno_pred,
                      'label': pheno_label.cpu().detach().numpy(),
                      'mask': None},
            'readmit': {'pred': readmit_pred,
                        'label': readmit_label.cpu().detach().numpy(),
                        'mask': readmit_mask.cpu().detach().numpy()},
            'ltm': {'pred': ltm_pred,
                    'label': ltm_label.cpu().detach().numpy(),
                    'mask': ltm_mask.cpu().detach().numpy()},
        }
        epoch_metrics.cache(outputs, num_valid_data)

    if train_step is not None:
        xpoint = train_step
    else:
        xpoint = epoch + 1
    epoch_metrics.write(writer, xpoint)
    # Average the accumulated losses over the number of batches seen.
    writer.add_scalar('{} loss'.format(split), running_loss / (i + 1), xpoint)
    writer.add_scalar('{} decomp loss'.format(split), running_loss_decomp / (i + 1), xpoint)
    writer.add_scalar('{} los loss'.format(split), running_loss_los / (i + 1), xpoint)
    writer.add_scalar('{} ihm loss'.format(split), running_loss_ihm / (i + 1), xpoint)
    writer.add_scalar('{} pheno loss'.format(split), running_loss_pheno / (i + 1), xpoint)
    writer.add_scalar('{} readmit loss'.format(split), running_loss_readmit / (i + 1), xpoint)
    writer.add_scalar('{} ltm loss'.format(split), running_loss_ltm / (i + 1), xpoint)
    if split == 'val':
        early_stopper(running_loss / (i + 1), model)
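
# NOTE: masked_weighted_cross_entropy_loss is used throughout but is not defined in this
# file. Below is a minimal sketch of what such a helper could look like, assuming logits of
# shape (N, C), integer labels of shape (N,), a 0/1 validity mask of shape (N,), and an
# optional per-class weight tensor; the actual implementation may differ.
def masked_weighted_cross_entropy_loss_sketch(class_weight, logits, labels, mask):
    """Cross entropy averaged over only the positions marked valid by `mask` (sketch)."""
    per_step = nn.functional.cross_entropy(
        logits, labels.long(), weight=class_weight, reduction='none')  # (N,)
    mask = mask.float().reshape(-1)
    # Average over valid positions only; guard against an all-zero mask.
    return (per_step * mask).sum() / mask.sum().clamp(min=1.0)
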
def train(epochs, train_data_loader, test_data_loader, early_stopper, model, optimizer, scheduler, device):
    criterion = nn.BCEWithLogitsLoss()
    crossentropyloss = nn.CrossEntropyLoss()
    model.to(device)
    train_b = 0
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 50)
        model.train()
        running_loss = 0.0
        running_loss_decomp = 0.0
        running_loss_los = 0.0
        running_loss_ihm = 0.0
        running_loss_pheno = 0.0
        running_loss_readmit = 0.0
        running_loss_ltm = 0.0
        epoch_metrics = utils.EpochWriter("Train", regression, experiment)
        tk0 = tqdm(train_data_loader, total=int(len(train_data_loader)))
        for i, data in enumerate(tk0):
            #------------------------ retrieve labels and masks per task -----------------#
            decomp_label, decomp_mask, los_label, los_mask, ihm_label, ihm_mask, \
                pheno_label, readmit_label, readmit_mask, ltm_label, ltm_mask, num_valid_data = retrieve_data(data, device)

            #------------------------- only load a modality if that modality is being used ----------------#
            if use_ts:
                ts = torch.from_numpy(data['time series'])
                ts = ts.permute(1, 0, 2).float().to(device)
            else:
                ts = None
            if use_text:
                texts, texts_weight_mat = text_embedding(embedding_layer, data, device)
            else:
                texts = None
                texts_weight_mat = None
            if use_tab:
                tab_dict = data['tab']
                for cat in tab_dict:
                    tab_dict[cat] = torch.from_numpy(tab_dict[cat]).long().to(device)
            else:
                tab_dict = None

            #---------------------------------- inference for all tasks ----------------------------------#
            decomp_logits, los_logits, ihm_logits, pheno_logits, readmit_logits, ltm_logits = model(
                ts=ts, texts=texts, texts_weight_mat=texts_weight_mat, tab_dict=tab_dict)

            #---------------------------------- compute losses per task ----------------------------------#
            loss_decomp = masked_weighted_cross_entropy_loss(None, decomp_logits, decomp_label, decomp_mask)
            loss_los = masked_weighted_cross_entropy_loss(los_class_weight, los_logits, los_label, los_mask)
            loss_ihm = masked_weighted_cross_entropy_loss(None, ihm_logits, ihm_label, ihm_mask)
            loss_pheno = criterion(pheno_logits, pheno_label)
            loss_readmit = masked_weighted_cross_entropy_loss(None, readmit_logits, readmit_label, readmit_mask)
            loss_ltm = masked_weighted_cross_entropy_loss(None, ltm_logits, ltm_label, ltm_mask)

            losses = {
                'decomp': loss_decomp,
                'ihm': loss_ihm,
                'los': loss_los,
                'pheno': loss_pheno,
                'readmit': loss_readmit,
                'ltm': loss_ltm,
            }
            loss = 0.0
            #------------------------- combine losses ----------------------------#
            for task in losses:
                #-------- uncertainty weighting -----------------#
                #prec = torch.exp(-log_var[task])
                #losses[task] = torch.sum(losses[task] * prec + log_var[task], -1)
                #loss += torch.sum(losses[task] * prec + log_var[task], -1)
                #-------- end uncertainty weighting stuff -------#
                loss += losses[task] * task_weight[task]

            train_b += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            #------------------------------- keep track of weighted per-task losses ----------------#
            running_loss += loss.item()
            running_loss_decomp += loss_decomp.item() * task_weight['decomp']
            running_loss_los += loss_los.item() * task_weight['los']
            running_loss_ihm += loss_ihm.item() * task_weight['ihm']
            running_loss_pheno += loss_pheno.item() * task_weight['pheno']
            running_loss_readmit += loss_readmit.item() * task_weight['readmit']
            running_loss_ltm += loss_ltm.item() * task_weight['ltm']

            m = nn.Softmax(dim=1)
            sig = nn.Sigmoid()
            #-------------------------- compute metrics per task -------------#
            decomp_pred = (sig(decomp_logits)[:, 1]).cpu().detach().numpy()
            los_pred = m(los_logits).cpu().detach().numpy()
            ihm_pred = (sig(ihm_logits)[:, 1]).cpu().detach().numpy()
            pheno_pred = sig(pheno_logits).cpu().detach().numpy()
            readmit_pred = m(readmit_logits).cpu().detach().numpy()
            ltm_pred = (sig(ltm_logits)[:, 1]).cpu().detach().numpy()
            #----------------------------- log metrics per task -------------#
            outputs = {
                'decomp': {'pred': decomp_pred,
                           'label': decomp_label.cpu().detach().numpy(),
                           'mask': decomp_mask.cpu().detach().numpy()},
                'ihm': {'pred': ihm_pred,
                        'label': ihm_label.cpu().detach().numpy(),
                        'mask': ihm_mask.cpu().detach().numpy()},
                'los': {'pred': los_pred,
                        'label': los_label.cpu().detach().numpy(),
                        'mask': los_mask.cpu().detach().numpy()},
                'pheno': {'pred': pheno_pred,
                          'label': pheno_label.cpu().detach().numpy(),
                          'mask': None},
                'readmit': {'pred': readmit_pred,
                            'label': readmit_label.cpu().detach().numpy(),
                            'mask': readmit_mask.cpu().detach().numpy()},
                'ltm': {'pred': ltm_pred,
                        'label': ltm_label.cpu().detach().numpy(),
                        'mask': ltm_mask.cpu().detach().numpy()},
            }
            epoch_metrics.cache(outputs, num_valid_data)

            # How often to log metrics.
            interval = 500
            if i % interval == interval - 1:
                writer.add_scalar('training loss', running_loss / interval, train_b)
                writer.add_scalar('decomp loss', running_loss_decomp / interval, train_b)
                writer.add_scalar('los loss', running_loss_los / interval, train_b)
                writer.add_scalar('ihm loss', running_loss_ihm / interval, train_b)
                writer.add_scalar('pheno loss', running_loss_pheno / interval, train_b)
                writer.add_scalar('readmit loss', running_loss_readmit / interval, train_b)
                writer.add_scalar('ltm loss', running_loss_ltm / interval, train_b)
                running_loss_decomp = 0.0
                running_loss_los = 0.0
                running_loss_ihm = 0.0
                running_loss_pheno = 0.0
                running_loss_readmit = 0.0
                running_loss_ltm = 0.0
                running_loss = 0.0
                epoch_metrics.add()

        epoch_metrics.write(writer, train_b)
        #scheduler.step()
        evaluate(epoch, val_data_loader, model, 'val', early_stopper, device, train_b)
        #evaluate(epoch, test_data_loader, model, 'test', early_stopper, device, train_b)
        if early_stopper.early_stop:
            # Stop training and evaluate on the test set.
            evaluate(epoch, test_data_loader, model, 'test', early_stopper, device, train_b)
            break
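
# NOTE: retrieve_data is not defined in this file. Judging from the raw-tensor versions of
# train/evaluate further below, it presumably converts the numpy arrays in the batch dict
# into tensors on `device`. The sketch below assumes batch keys such as 'readmit label',
# 'readmit mask', 'ltm label', and 'ltm mask' mirror the 'decomp label' / 'ihm label'
# naming; treat it as illustrative only, not the original helper.
def retrieve_data_sketch(data, device):
    def flat_long(key):
        # Flatten (b, t) labels to (b*t,) and move them to the target device.
        arr = np.array(data[key]).reshape(-1)
        return torch.from_numpy(arr).long().to(device)

    def as_mask(key):
        return torch.from_numpy(np.array(data[key])).long().to(device)

    # num_valid_data is used in case the last batch is smaller than the batch size.
    num_valid_data = data['decomp label'].shape[0]
    pheno_label = torch.from_numpy(np.array(data['pheno label'])).float().to(device)
    return (flat_long('decomp label'), as_mask('decomp mask'),
            flat_long('los label'), as_mask('los mask'),
            flat_long('ihm label'), as_mask('ihm mask'),
            pheno_label,
            flat_long('readmit label'), as_mask('readmit mask'),
            flat_long('ltm label'), as_mask('ltm mask'),
            num_valid_data)
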
def evaluate(epoch, data_loader, model, device, train_step=None, split="Test"):
    embedding_layer = nn.Embedding(vectors.shape[0], vectors.shape[1])
    embedding_layer.weight.data.copy_(torch.from_numpy(vectors))
    embedding_layer.weight.requires_grad = False
    '''
    if os.path.exists('./ckpt/multitasking_experiment_ts_text_union_cw.ckpt'):
        model.load_state_dict(torch.load('./ckpt/multitasking_experiment_ts_text_union_cw.ckpt'))
    '''
    epoch_metrics = utils.EpochWriter(split, regression)
    model.to(device)
    model.eval()
    total_data_points = 0
    running_loss = 0.0
    running_loss_decomp = 0.0
    running_loss_los = 0.0
    running_loss_ihm = 0.0
    running_loss_pheno = 0.0
    tk = tqdm(data_loader, total=int(len(data_loader)))
    criterion = nn.BCEWithLogitsLoss()
    for i, data in enumerate(tk):
        # if i >= 50:
        #     break
        if data is None:
            continue
        num_valid_data = data['decomp label'].shape[0]
        total_data_points += num_valid_data
        name = data['name'][0]
        # print(name)

        ihm_mask = torch.from_numpy(np.array(data['ihm mask'])).long()
        ihm_mask = ihm_mask.to(device)
        ihm_label = np.array(data['ihm label'])
        ihm_label = ihm_label.reshape(-1, 1).squeeze(1)
        ihm_label = torch.from_numpy(ihm_label).long().to(device)

        decomp_mask = (torch.from_numpy(data['decomp mask'])).long().to(device)
        decomp_label = data['decomp label']
        #num_valid_data = decomp_label.shape[0]
        decomp_label = decomp_label.reshape(-1, 1).squeeze(1)  # (b*t,)
        decomp_label = torch.from_numpy(decomp_label).long().to(device)

        los_mask = torch.from_numpy(np.array(data['los mask'])).long().to(device)
        #print('los mask', los_mask)
        #np.save('./res/mask/los_mask_{}.npy'.format(name), los_mask)
        los_label = np.array(data['los label'])
        #np.save('./res/label/los_label_{}.npy'.format(name), los_label)
        los_label = los_label.reshape(-1, 1).squeeze(1)
        los_label = torch.from_numpy(los_label).to(device)

        pheno_label = torch.from_numpy(np.array(data['pheno label'])).float().to(device)

        if use_ts:
            ts = torch.from_numpy(data['time series'])
            ts = ts.permute(1, 0, 2).float().to(device)
        else:
            ts = None
        if use_wf:
            waveforms = torch.from_numpy(data['waveforms']).float()
            waveforms_weight_mat = torch.from_numpy(data['waveforms weight mat']).float()
            waveforms_weight_mat = waveforms_weight_mat.to(device)
            waveforms = waveforms.to(device)
        else:
            waveforms = None
            waveforms_weight_mat = None
        if use_text:
            texts = torch.from_numpy(data['texts']).to(torch.int64)
            texts_weight_mat = torch.from_numpy(data['texts weight mat']).float()
            texts_weight_mat = texts_weight_mat.to(device)
            texts = embedding_layer(texts)  # [batch_size, num_docs, seq_len, emb_dim]
            texts = texts.to(device)
            if text_model.name == 'avg':
                texts = avg_emb(texts, texts_weight_mat)
            # t = ts.shape[0]
            # b = ts.shape[1]
            # texts = torch.rand(b*t, 768).float().to(device)
            # texts_weight_mat = None
        else:
            texts = None
            texts_weight_mat = None

        decomp_logits, los_logits, ihm_logits, pheno_logits = model(
            ts=ts, texts=texts, texts_weight_mat=texts_weight_mat,
            waveforms=waveforms, waveforms_weight_mat=waveforms_weight_mat)

        loss_decomp = masked_weighted_cross_entropy_loss(None, decomp_logits, decomp_label, decomp_mask)
        loss_los = masked_weighted_cross_entropy_loss(None, los_logits, los_label, los_mask)
        # loss_los = masked_mse_loss(los_logits, los_label, los_mask)
        loss_ihm = masked_weighted_cross_entropy_loss(None, ihm_logits, ihm_label, ihm_mask)
        loss_pheno = criterion(pheno_logits, pheno_label)

        losses = {
            'decomp': loss_decomp,
            'ihm': loss_ihm,
            'los': loss_los,
            'pheno': loss_pheno
        }
        loss = 0.0
        for task in losses:
            loss += losses[task] * task_weight[task]

        running_loss += loss.item()
        running_loss_decomp += loss_decomp.item() * task_weight['decomp']
        running_loss_los += loss_los.item() * task_weight['los']
        running_loss_ihm += loss_ihm.item() * task_weight['ihm']
        running_loss_pheno += loss_pheno.item() * task_weight['pheno']

        m = nn.Softmax(dim=1)
        sigmoid = nn.Sigmoid()
        decomp_pred = (sigmoid(decomp_logits)[:, 1]).cpu().detach().numpy()
        #los_pred = torch.argmax(m(los_logits), dim=1).cpu().detach().numpy()
        los_pred = m(los_logits).cpu().detach().numpy()
        #np.save('./res/pred/los_pred_{}.npy'.format(name), los_pred)
        #los_pred = los_logits.cpu().detach().numpy()
        ihm_pred = (sigmoid(ihm_logits)[:, 1]).cpu().detach().numpy()
        pheno_pred = sigmoid(pheno_logits).cpu().detach().numpy()
        #print(pheno_pred.sum(dim=1))

        outputs = {
            'decomp': {
                'pred': decomp_pred,
                'label': decomp_label.cpu().detach().numpy(),
                'mask': decomp_mask.cpu().detach().numpy()
            },
            'ihm': {
                'pred': ihm_pred,
                'label': ihm_label.cpu().detach().numpy(),
                'mask': ihm_mask.cpu().detach().numpy()
            },
            'los': {
                'pred': los_pred,
                'label': los_label.cpu().detach().numpy(),
                'mask': los_mask.cpu().detach().numpy()
            },
            'pheno': {
                'pred': pheno_pred,
                'label': pheno_label.cpu().detach().numpy(),
                'mask': None
            },
        }
        epoch_metrics.cache(outputs, num_valid_data)
        #unique_elements, counts_elements = np.unique(metric_los_kappa.y_pred, return_counts=True)
        #labels, counts = np.unique(metric_los_kappa.y_true, return_counts=True)

    if train_step is not None:
        xpoint = train_step
    else:
        xpoint = epoch + 1
    epoch_metrics.write(writer, xpoint)
    # Average the accumulated losses over the number of batches seen.
    writer.add_scalar('{} loss'.format(split), running_loss / (i + 1), xpoint)
    writer.add_scalar('{} decomp loss'.format(split), running_loss_decomp / (i + 1), xpoint)
    writer.add_scalar('{} los loss'.format(split), running_loss_los / (i + 1), xpoint)
    writer.add_scalar('{} ihm loss'.format(split), running_loss_ihm / (i + 1), xpoint)
    writer.add_scalar('{} pheno loss'.format(split), running_loss_pheno / (i + 1), xpoint)
def train(epochs, train_data_loader, test_data_loader, model, optimizer, device):
    # Class-frequency weights for each task.
    decomp_weight = torch.FloatTensor([1.0214, 47.6688])
    decomp_weight = decomp_weight / decomp_weight.sum()
    ihm_weight = torch.FloatTensor([1.1565, 7.3888])
    ihm_weight = ihm_weight / ihm_weight.sum()
    # los_weight = torch.FloatTensor([66.9758, 30.3148, 13.7411, 6.8861, 4.8724, 4.8037, 5.7935,
    #                                 8.9295, 29.8249, 391.6768])
    los_weight = torch.FloatTensor([1.6047, 3.8934, 8.3376])
    #los_weight = los_weight / los_weight.sum()
    pheno_weight = torch.FloatTensor([
        19.2544, 55.1893, 40.1445, 12.8604, 30.7595, 31.4979, 19.9768, 57.2309,
        15.4088, 12.8200, 43.2644, 21.3991, 14.2026, 9.8531, 15.3284, 57.1641,
        31.0782, 46.4064, 81.0640, 102.7755, 47.5936, 29.6070, 22.7682, 28.8175,
        52.8856
    ])
    pheno_weight = pheno_weight / pheno_weight.sum()
    pheno_weight = pheno_weight.to(device)
    criterion = nn.BCEWithLogitsLoss()

    embedding_layer = nn.Embedding(vectors.shape[0], vectors.shape[1])
    embedding_layer.weight.data.copy_(torch.from_numpy(vectors))
    embedding_layer.weight.requires_grad = False

    model.to(device)
    train_b = 0
    running_loss_decomp = 0.0
    running_loss_los = 0.0
    running_loss_ihm = 0.0
    running_loss_pheno = 0.0
    running_loss = 0.0
    for epoch in range(epochs):
        total_data_points = 0
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 50)
        model.train()
        running_loss = 0.0
        running_loss_decomp = 0.0
        running_loss_los = 0.0
        running_loss_ihm = 0.0
        running_loss_pheno = 0.0
        epoch_metrics = utils.EpochWriter("Train", regression)
        tk0 = tqdm(train_data_loader, total=int(len(train_data_loader)))
        '''
        if os.path.exists('./ckpt/multitasking_experiment_ts_text_union_cw.ckpt'):
            model.load_state_dict(torch.load('./ckpt/multitasking_experiment_ts_text_union_cw.ckpt'))
        '''
        for i, data in enumerate(tk0):
            if data is None:
                continue
            ihm_mask = torch.from_numpy(np.array(data['ihm mask']))
            ihm_mask = ihm_mask.to(device)
            ihm_label = torch.from_numpy(np.array(data['ihm label'])).long()
            ihm_label = ihm_label.reshape(-1, 1).squeeze(1)
            ihm_label = ihm_label.to(device)
            ihm_weight = ihm_weight.to(device)

            decomp_mask = torch.from_numpy(data['decomp mask'])
            decomp_mask = decomp_mask.to(device)
            decomp_label = torch.from_numpy(data['decomp label']).long()
            # num_valid_data is used in case the last batch is smaller than the batch size.
            num_valid_data = decomp_label.shape[0]
            total_data_points += num_valid_data
            decomp_label = decomp_label.reshape(-1, 1).squeeze(1)  # (b*t,)
            decomp_label = decomp_label.to(device)
            decomp_weight = decomp_weight.to(device)

            los_mask = torch.from_numpy(np.array(data['los mask']))
            los_mask = los_mask.to(device)
            los_label = torch.from_numpy(np.array(data['los label']))
            los_label = los_label.reshape(-1, 1).squeeze(1)
            los_label = los_label.to(device)
            los_weight = los_weight.to(device)

            pheno_label = torch.from_numpy(np.array(data['pheno label'])).float()
            pheno_label = pheno_label.to(device)

            if use_ts:
                ts = torch.from_numpy(data['time series'])
                ts = ts.permute(1, 0, 2).float().to(device)
            else:
                ts = None
            if use_wf:
                waveforms = torch.from_numpy(data['waveforms']).float()
                waveforms_weight_mat = torch.from_numpy(data['waveforms weight mat']).float()
                waveforms_weight_mat = waveforms_weight_mat.to(device)
                waveforms = waveforms.to(device)
            else:
                waveforms = None
                waveforms_weight_mat = None
            if use_text:
                texts = torch.from_numpy(data['texts']).to(torch.int64)
                texts_weight_mat = torch.from_numpy(data['texts weight mat']).float()
                texts_weight_mat = texts_weight_mat.to(device)
                texts = embedding_layer(texts)  # [batch_size, num_docs, seq_len, emb_dim]
                #print(texts.shape)
                texts = texts.to(device)
                #print(texts.shape)
                if text_model.name == 'avg':
                    texts = avg_emb(texts, texts_weight_mat)
                # t = ts.shape[0]
                # b = ts.shape[1]
                # texts = torch.rand(b*t, 768).float().to(device)
                # texts_weight_mat = None
            else:
                texts = None
                texts_weight_mat = None

            decomp_logits, los_logits, ihm_logits, pheno_logits = model(
                ts=ts, texts=texts, texts_weight_mat=texts_weight_mat,
                waveforms=waveforms, waveforms_weight_mat=waveforms_weight_mat)

            loss_decomp = masked_weighted_cross_entropy_loss(None, decomp_logits, decomp_label, decomp_mask)
            #loss_los = masked_weighted_cross_entropy_loss(None, los_logits, los_label, los_mask)
            loss_los = masked_weighted_cross_entropy_loss(los_weight, los_logits, los_label, los_mask)
            #print(loss_los.item())
            loss_ihm = masked_weighted_cross_entropy_loss(None, ihm_logits, ihm_label, ihm_mask)
            loss_pheno = criterion(pheno_logits, pheno_label)

            losses = {
                'decomp': loss_decomp,
                'ihm': loss_ihm,
                'los': loss_los,
                'pheno': loss_pheno
            }
            loss = 0.0
            for task in losses:
                #loss += losses[task]
                # Uncertainty weighting: scale each task loss by its learned precision.
                prec = torch.exp(-log_var[task])
                #losses[task] = torch.sum(losses[task] * prec + log_var[task], -1)
                loss += torch.sum(losses[task] * prec + log_var[task], -1)
                #loss += losses[task] * task_weight[task]
                #loss += losses[task]
            #loss = torch.mean(loss)
            #loss = loss_decomp*5 + loss_ihm*2 + loss_los*1 + loss_pheno*2
            #loss = loss_los

            train_b += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_loss_decomp += loss_decomp.item() * task_weight['decomp']
            running_loss_los += loss_los.item() * task_weight['los']
            running_loss_ihm += loss_ihm.item() * task_weight['ihm']
            running_loss_pheno += loss_pheno.item() * task_weight['pheno']

            m = nn.Softmax(dim=1)
            sig = nn.Sigmoid()
            decomp_pred = (m(decomp_logits)[:, 1]).cpu().detach().numpy()
            #los_pred = torch.argmax(m(los_logits), dim=1).cpu().detach().numpy()
            los_pred = m(los_logits).cpu().detach().numpy()
            #los_pred = los_logits.cpu().detach().numpy()
            ihm_pred = (m(ihm_logits)[:, 1]).cpu().detach().numpy()
            pheno_pred = sig(pheno_logits).cpu().detach().numpy()

            outputs = {
                'decomp': {
                    'pred': decomp_pred,
                    'label': decomp_label.cpu().detach().numpy(),
                    'mask': decomp_mask.cpu().detach().numpy()
                },
                'ihm': {
                    'pred': ihm_pred,
                    'label': ihm_label.cpu().detach().numpy(),
                    'mask': ihm_mask.cpu().detach().numpy()
                },
                'los': {
                    'pred': los_pred,
                    'label': los_label.cpu().detach().numpy(),
                    'mask': los_mask.cpu().detach().numpy()
                },
                'pheno': {
                    'pred': pheno_pred,
                    'label': pheno_label.cpu().detach().numpy(),
                    'mask': None
                },
            }
            epoch_metrics.cache(outputs, num_valid_data)

            interval = 500
            if i % interval == interval - 1:
                writer.add_scalar('training loss', running_loss / interval, train_b)
                writer.add_scalar('decomp loss', running_loss_decomp / interval, train_b)
                writer.add_scalar('los loss', running_loss_los / interval, train_b)
                writer.add_scalar('ihm loss', running_loss_ihm / interval, train_b)
                writer.add_scalar('pheno loss', running_loss_pheno / interval, train_b)
                '''
                print('loss decomp', running_loss_decomp/interval)
                print('loss ihm', running_loss_ihm/interval)
                print('loss los', running_loss_los/interval)
                print('loss pheno', running_loss_pheno/interval)
                print('epoch {} , training loss is {:.3f}'.format(epoch+1, running_loss_los/interval))
                '''
                running_loss_decomp = 0.0
                running_loss_los = 0.0
                running_loss_ihm = 0.0
                running_loss_pheno = 0.0
                running_loss = 0.0
                epoch_metrics.add()

        # Report the learned task precisions once per epoch.
        for task in losses:
            print(task, torch.exp(-log_var[task]))
        epoch_metrics.write(writer, train_b)
        torch.save(
            model.state_dict(),
            os.path.join('./ckpt/', 'epoch{0}'.format(epoch) + model_weights))
        evaluate(epoch, val_data_loader, model, device, train_b, "Val")
        evaluate(epoch, test_data_loader, model, device, train_b, "Test")
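
# NOTE: the uncertainty-weighted combination above relies on a `log_var` dict that is not
# defined in this file. One way to set it up, following Kendall et al.'s multi-task
# uncertainty weighting, is sketched below; the task names, zero initialisation, and the
# optimiser settings in the usage comment are assumptions, not the original configuration.
def make_log_var_sketch(device, tasks=('decomp', 'ihm', 'los', 'pheno')):
    # One learnable log-variance scalar per task; exp(-log_var) acts as the task precision.
    return {t: torch.zeros(1, requires_grad=True, device=device) for t in tasks}

# These parameters must be optimised jointly with the model, e.g.:
#   log_var = make_log_var_sketch(device)
#   optimizer = torch.optim.Adam(list(model.parameters()) + list(log_var.values()), lr=1e-4)
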
def evaluate(epoch, data_loader, model, split, early_stopper, device, train_step=None):
    aucroc_readmit = utils.AUCROCREADMIT()
    aucpr_readmit = utils.AUCPRREADMIT()
    cfm_readmit = utils.ConfusionMatrixReadmit()
    if split == 'val':
        epoch_metrics = utils.EpochWriter("Val", regression, experiment)
    else:
        epoch_metrics = utils.EpochWriter("Test", regression, experiment)
    model.to(device)
    model.eval()
    running_loss = 0.0
    tk = tqdm(data_loader, total=int(len(data_loader)))
    criterion = nn.BCEWithLogitsLoss()
    for i, data in enumerate(tk):
        if data is None:
            continue
        decomp_label, decomp_mask, los_label, los_mask, ihm_label, ihm_mask, \
            pheno_label, readmit_label, readmit_mask, num_valid_data = retrieve_data(data, device)
        if use_ts:
            ts = torch.from_numpy(data['time series'])
            ts = ts.permute(1, 0, 2).float().to(device)
        else:
            ts = None
        if use_text:
            texts = text_embedding(embedding_layer, data, device)
        else:
            texts = None

        readmit_logits = model(texts=texts)
        loss = masked_weighted_cross_entropy_loss(None, readmit_logits, readmit_label, readmit_mask)
        running_loss += loss.item()

        sigmoid = nn.Sigmoid()
        readmit_pred = (sigmoid(readmit_logits)[:, 1]).cpu().detach().numpy()
        readmit_label = readmit_label.cpu().detach().numpy()
        readmit_mask = readmit_mask.cpu().detach().numpy()
        aucpr_readmit.add(readmit_pred, readmit_label, readmit_mask)
        aucroc_readmit.add(readmit_pred, readmit_label, readmit_mask)
        cfm_readmit.add(readmit_pred, readmit_label, readmit_mask)

    print('readmission aucpr is {}'.format(aucpr_readmit.get()))
    print('readmission aucroc is {}'.format(aucroc_readmit.get()))
    print('readmission cfm is {}'.format(cfm_readmit.get()))

    if train_step is not None:
        xpoint = train_step
    else:
        xpoint = epoch + 1
    # Average the accumulated loss over the number of batches seen.
    writer.add_scalar('{} readmit loss'.format(split), running_loss / (i + 1), xpoint)
    if split == 'val':
        early_stopper(running_loss / (i + 1), model)
def train(epochs, train_data_loader, test_data_loader, early_stopper, model, optimizer, scheduler, device):
    criterion = nn.BCEWithLogitsLoss()
    aucroc_readmit = utils.AUCROCREADMIT()
    aucpr_readmit = utils.AUCPRREADMIT()
    cfm_readmit = utils.ConfusionMatrixReadmit()
    model.to(device)
    train_b = 0
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 50)
        model.train()
        running_loss = 0.0
        epoch_metrics = utils.EpochWriter("Train", regression, experiment)
        tk0 = tqdm(train_data_loader, total=int(len(train_data_loader)))
        for i, data in enumerate(tk0):
            if data is None:
                continue
            decomp_label, decomp_mask, los_label, los_mask, ihm_label, ihm_mask, \
                pheno_label, readmit_label, readmit_mask, num_valid_data = retrieve_data(data, device)
            if use_text:
                texts = text_embedding(embedding_layer, data, device)
            else:
                texts = None

            readmit_logits = model(texts=texts)
            loss = masked_weighted_cross_entropy_loss(None, readmit_logits, readmit_label, readmit_mask)

            train_b += 1
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            m = nn.Softmax(dim=1)
            sig = nn.Sigmoid()
            readmit_pred = (sig(readmit_logits)[:, 1]).cpu().detach().numpy()
            readmit_label = readmit_label.cpu().detach().numpy()
            if readmit_label is None:
                print('bad')
            readmit_mask = readmit_mask.cpu().detach().numpy()
            aucpr_readmit.add(readmit_pred, readmit_label, readmit_mask)
            aucroc_readmit.add(readmit_pred, readmit_label, readmit_mask)
            cfm_readmit.add(readmit_pred, readmit_label, readmit_mask)

            interval = 50
            if i % interval == interval - 1:
                writer.add_scalar('training loss', running_loss / interval, train_b)
                print('readmission aucpr is {}'.format(aucpr_readmit.get()))
                print('readmission aucroc is {}'.format(aucroc_readmit.get()))
                print('readmission cfm is {}'.format(cfm_readmit.get()))

        #scheduler.step()
        #evaluate(epoch, val_data_loader, model, 'val', early_stopper, device, train_b)
        evaluate(epoch, test_data_loader, model, 'test', early_stopper, device, train_b)
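
# NOTE: `early_stopper` is passed into train/evaluate above and used as
# `early_stopper(val_loss, model)` together with an `early_stop` flag, but its class is not
# defined in this file. A minimal sketch of a compatible helper is given below; the patience
# value and checkpoint path are placeholders, not values taken from the original code.
class EarlyStoppingSketch:
    def __init__(self, patience=5, ckpt_path='./ckpt/best_model.ckpt'):
        self.patience = patience
        self.ckpt_path = ckpt_path
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss:
            # Validation loss improved: save a checkpoint and reset the patience counter.
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.ckpt_path)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True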