Example 1
def evaluate(epoch, data_loader, model, split, early_stopper, device, train_step=None):

    if split == 'val':
        epoch_metrics = utils.EpochWriter("Val", regression, experiment)
    else:
        epoch_metrics = utils.EpochWriter("Test", regression, experiment)

    model.eval()

    running_loss = 0.0
    running_loss_decomp = 0.0
    running_loss_los = 0.0
    running_loss_ihm = 0.0
    running_loss_pheno = 0.0
    running_loss_readmit = 0.0
    running_loss_ltm = 0.0
    tk = tqdm(data_loader, total=int(len(data_loader)))
    criterion = nn.BCEWithLogitsLoss()

    for i, data in enumerate(tk):
        if data is None:
            continue
        
        decomp_label, decomp_mask, los_label, los_mask, ihm_label, ihm_mask,\
                 pheno_label, readmit_label, readmit_mask, ltm_label, ltm_mask, num_valid_data = retrieve_data(data, device)
        
        if use_ts:
            ts = torch.from_numpy(data['time series'])
            ts = ts.permute(1, 0, 2).float().to(device)
        else:
            ts = None

        
        if use_text:
            texts, texts_weight_mat = text_embedding(embedding_layer, data, device)
        else:
            texts = None
            texts_weight_mat = None
        
        if use_tab:
            tab_dict = data['tab']
            for cat in tab_dict:
                tab_dict[cat] = torch.from_numpy(tab_dict[cat]).long().to(device)
        else:
            tab_dict = None

        decomp_logits, los_logits, ihm_logits, pheno_logits, readmit_logits, ltm_logits = model(
            ts=ts, texts=texts, texts_weight_mat=texts_weight_mat, tab_dict=tab_dict)

        loss_decomp = masked_weighted_cross_entropy_loss(None, 
                                                        decomp_logits, 
                                                        decomp_label, 
                                                        decomp_mask)
        loss_los = masked_weighted_cross_entropy_loss(los_class_weight, 
                                                      los_logits,
                                                      los_label, 
                                                      los_mask)

        loss_ihm = masked_weighted_cross_entropy_loss(None, ihm_logits,
                                                      ihm_label,
                                                      ihm_mask)
        loss_pheno = criterion(pheno_logits, pheno_label)
        loss_readmit = masked_weighted_cross_entropy_loss(readmit_class_weight, readmit_logits, readmit_label, readmit_mask)
        loss_ltm = masked_weighted_cross_entropy_loss(None, ltm_logits, ltm_label, ltm_mask)
        
        losses = {
            'decomp' : loss_decomp,
            'ihm'    : loss_ihm,
            'los'    : loss_los,
            'pheno'  : loss_pheno,
            'readmit': loss_readmit,
            'ltm'    : loss_ltm,
        }
        loss = 0.0

        for task in losses:
            loss += losses[task] * task_weight[task]

        running_loss += losses[target_task].item() * task_weight[target_task]
        running_loss_decomp += loss_decomp.item() * task_weight['decomp']
        running_loss_los += loss_los.item() * task_weight['los']
        running_loss_ihm += loss_ihm.item() * task_weight['ihm']
        running_loss_pheno += loss_pheno.item() * task_weight['pheno']
        running_loss_readmit += loss_readmit.item() * task_weight['readmit']
        running_loss_ltm += loss_ltm.item() * task_weight['ltm']


        m = nn.Softmax(dim=1)
        sigmoid = nn.Sigmoid()
        
        decomp_pred = (sigmoid(decomp_logits)[:,1]).cpu().detach().numpy()
        los_pred = m(los_logits).cpu().detach().numpy()
        ihm_pred = (sigmoid(ihm_logits)[:,1]).cpu().detach().numpy()
        pheno_pred = sigmoid(pheno_logits).cpu().detach().numpy()
        readmit_pred = m(readmit_logits).cpu().detach().numpy()
        ltm_pred = (sigmoid(ltm_logits)[:,1]).cpu().detach().numpy()
        #print(sigmoid(readmit_logits))

        outputs = {
            'decomp': {'pred': decomp_pred,
                        'label': decomp_label.cpu().detach().numpy(),
                        'mask': decomp_mask.cpu().detach().numpy()},
            'ihm': {'pred': ihm_pred,
                        'label': ihm_label.cpu().detach().numpy(),
                        'mask': ihm_mask.cpu().detach().numpy()},
            'los': {'pred': los_pred,
                        'label': los_label.cpu().detach().numpy(),
                        'mask': los_mask.cpu().detach().numpy()},
            'pheno': {'pred': pheno_pred,
                        'label': pheno_label.cpu().detach().numpy(),
                        'mask': None},
            'readmit': {'pred': readmit_pred,
                            'label':readmit_label.cpu().detach().numpy(),
                            'mask': readmit_mask.cpu().detach().numpy()},
            'ltm': {'pred': ltm_pred,
                        'label': ltm_label.cpu().detach().numpy(),
                        'mask': ltm_mask.cpu().detach().numpy()},
        }

        epoch_metrics.cache(outputs, num_valid_data)

    if train_step is not None:
        xpoint = train_step
    else:
        xpoint = epoch + 1

    epoch_metrics.write(writer, xpoint)
    # i holds the last batch index, so i + 1 batches were accumulated
    writer.add_scalar('{} loss'.format(split),
                running_loss / (i + 1),
                xpoint)
    writer.add_scalar('{} decomp loss'.format(split),
                running_loss_decomp / (i + 1),
                xpoint)
    writer.add_scalar('{} los loss'.format(split),
                running_loss_los / (i + 1),
                xpoint)
    writer.add_scalar('{} ihm loss'.format(split),
                running_loss_ihm / (i + 1),
                xpoint)
    writer.add_scalar('{} pheno loss'.format(split),
                running_loss_pheno / (i + 1),
                xpoint)
    writer.add_scalar('{} readmit loss'.format(split),
                running_loss_readmit / (i + 1),
                xpoint)
    writer.add_scalar('{} ltm loss'.format(split),
                running_loss_ltm / (i + 1),
                xpoint)

    if split == 'val':
        early_stopper(running_loss / (i + 1), model)
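
All six examples call a masked_weighted_cross_entropy_loss helper whose definition is not shown. The sketch below is only an inference from the call sites, assuming logits of shape (N, num_classes), integer labels of shape (N,), a 0/1 validity mask of shape (N,), and an optional per-class weight tensor:

import torch
import torch.nn.functional as F

def masked_weighted_cross_entropy_loss(weight, logits, labels, mask):
    # per-position cross entropy; weight may be None or a (num_classes,) tensor
    per_pos = F.cross_entropy(logits, labels, weight=weight, reduction='none')
    mask = mask.float()
    # average only over the positions the mask marks as valid
    return (per_pos * mask).sum() / mask.sum().clamp(min=1.0)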
Example 2
def train(epochs, train_data_loader, test_data_loader, early_stopper, model, optimizer, scheduler, device):

    criterion = nn.BCEWithLogitsLoss()
    crossentropyloss = nn.CrossEntropyLoss()

    model.to(device)
    train_b = 0

    for epoch in range(epochs):

        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 50)
        model.train()

        running_loss = 0.0
        running_loss_decomp = 0.0
        running_loss_los = 0.0
        running_loss_ihm = 0.0
        running_loss_pheno = 0.0
        running_loss_readmit = 0.0
        running_loss_ltm = 0.0

        epoch_metrics = utils.EpochWriter("Train", regression, experiment)

        tk0 = tqdm(train_data_loader, total=int(len(train_data_loader)))

        for i, data in enumerate(tk0):
            if data is None:
                continue

            #------------------------ retrieve labels and masks per task -----------------#
            decomp_label, decomp_mask, los_label, los_mask, ihm_label, ihm_mask,\
                 pheno_label, readmit_label, readmit_mask, ltm_label, ltm_mask, num_valid_data = retrieve_data(data, device)
            
            #------------------------- only load modality if that modality is being used ----------------#
            if use_ts:
                ts = torch.from_numpy(data['time series'])
                ts = ts.permute(1,0,2).float().to(device)
            else:
                ts = None
          
            
            if use_text:
                texts, texts_weight_mat = text_embedding(embedding_layer, data, device)
            else:
                texts = None
                texts_weight_mat = None
            if use_tab:
                tab_dict = data['tab']
                for cat in tab_dict:
                    tab_dict[cat] = torch.from_numpy(tab_dict[cat]).long().to(device)
            else:
                tab_dict = None

            #---------------------------------- inference for all tasks ----------------------------------------------------#
            decomp_logits, los_logits, ihm_logits, pheno_logits, readmit_logits, ltm_logits = model(
                ts=ts, texts=texts, texts_weight_mat=texts_weight_mat, tab_dict=tab_dict)

            #----------------------------------compute losses per task -----------------------------------------#
            loss_decomp = masked_weighted_cross_entropy_loss(None, decomp_logits, decomp_label, decomp_mask)
            loss_los = masked_weighted_cross_entropy_loss(los_class_weight, los_logits, los_label, los_mask)
            loss_ihm = masked_weighted_cross_entropy_loss(None, ihm_logits, ihm_label, ihm_mask)
            loss_pheno = criterion(pheno_logits, pheno_label)
            loss_readmit = masked_weighted_cross_entropy_loss(None, readmit_logits, readmit_label, readmit_mask)
            loss_ltm = masked_weighted_cross_entropy_loss(None, ltm_logits, ltm_label, ltm_mask)

            losses = {
                'decomp' : loss_decomp,
                'ihm'    : loss_ihm,
                'los'    : loss_los,
                'pheno'  : loss_pheno,
                'readmit': loss_readmit,
                'ltm'    : loss_ltm, 
            }

            loss = 0.0


            #------------------------- combine losses ----------------------------#
            for task in losses:
                #-------- uncertainty weighting -----------------#
                #prec = torch.exp(-log_var[task])
                #losses[task] = torch.sum(losses[task] * prec + log_var[task], -1)
                #loss += torch.sum(losses[task] * prec + log_var[task], -1)
                #-------- end uncertainty weighting stuff -------#
 
                loss += losses[task] * task_weight[task]
            
            train_b += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()


            #------------------------------- keep track of weighted per task losses ----------------#
            running_loss += loss.item()
            running_loss_decomp += loss_decomp.item() * task_weight['decomp']
            running_loss_los += loss_los.item() * task_weight['los']
            running_loss_ihm += loss_ihm.item() * task_weight['ihm']
            running_loss_pheno += loss_pheno.item() * task_weight['pheno']
            running_loss_readmit += loss_readmit.item() * task_weight['readmit']
            running_loss_ltm += loss_ltm.item() * task_weight['ltm']


            m = nn.Softmax(dim=1)
            sig = nn.Sigmoid()


            #-------------------------- compute metrics per task -------------#
            decomp_pred = (sig(decomp_logits)[:, 1]).cpu().detach().numpy()
            los_pred = m(los_logits).cpu().detach().numpy()
            ihm_pred = (sig(ihm_logits)[:, 1]).cpu().detach().numpy()
            pheno_pred = sig(pheno_logits).cpu().detach().numpy()
            readmit_pred = m(readmit_logits).cpu().detach().numpy()
            ltm_pred = (sig(ltm_logits)[:,1]).cpu().detach().numpy()
            


            #----------------------------- log metrics per task -------------#
            outputs = {
                'decomp': {'pred': decomp_pred,
                           'label': decomp_label.cpu().detach().numpy(),
                           'mask': decomp_mask.cpu().detach().numpy()},
                'ihm': {'pred': ihm_pred,
                           'label': ihm_label.cpu().detach().numpy(),
                           'mask': ihm_mask.cpu().detach().numpy()},
                'los': {'pred': los_pred,
                           'label': los_label.cpu().detach().numpy(),
                           'mask': los_mask.cpu().detach().numpy()},
                'pheno': {'pred': pheno_pred,
                           'label': pheno_label.cpu().detach().numpy(),
                           'mask': None},
                'readmit': {'pred': readmit_pred,
                            'label':readmit_label.cpu().detach().numpy(),
                            'mask': readmit_mask.cpu().detach().numpy()},
                'ltm': {'pred': ltm_pred,
                        'label': ltm_label.cpu().detach().numpy(),
                        'mask': ltm_mask.cpu().detach().numpy()},
            }
            epoch_metrics.cache(outputs, num_valid_data)
            # how often to log metrics
            interval = 500

            if i % interval == interval - 1:
                # `interval` batches were accumulated since the last reset
                writer.add_scalar('training loss',
                            running_loss / interval,
                            train_b)
                writer.add_scalar('decomp loss',
                            running_loss_decomp / interval,
                            train_b)
                writer.add_scalar('los loss',
                            running_loss_los / interval,
                            train_b)
                writer.add_scalar('ihm loss',
                            running_loss_ihm / interval,
                            train_b)
                writer.add_scalar('pheno loss',
                            running_loss_pheno / interval,
                            train_b)
                writer.add_scalar('readmit loss',
                            running_loss_readmit / interval,
                            train_b)
                writer.add_scalar('ltm loss',
                            running_loss_ltm / interval,
                            train_b)

                running_loss_decomp = 0.0
                running_loss_los = 0.0
                running_loss_ihm = 0.0
                running_loss_pheno = 0.0
                running_loss_readmit = 0.0
                running_loss_ltm = 0.0
                running_loss = 0.0
                epoch_metrics.add()

        epoch_metrics.write(writer, train_b)

        #scheduler.step()
        evaluate(epoch, val_data_loader, model, 'val', early_stopper, device, train_b)
        #evaluate(epoch, test_data_loader, model, 'test', early_stopper, device, train_b)

        if early_stopper.early_stop:
            # stop training and evaluate on the test set
            evaluate(epoch, test_data_loader, model, 'test', early_stopper, device, train_b)
            break
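
Examples 1 and 2 unpack each batch through a retrieve_data helper (Examples 5 and 6 call a variant without the ltm task), and Example 3 below performs the same unpacking inline. A hedged sketch of a compatible helper: the readmit/ltm key names are assumptions, the rest mirrors Example 3.

import numpy as np
import torch

def retrieve_data(data, device):
    # sketch: move labels/masks to the device; labels are flattened to
    # (b*t,) the way Example 3 does inline
    def to_dev(key, flatten=False):
        arr = np.asarray(data[key])
        if flatten:
            arr = arr.reshape(-1)
        return torch.from_numpy(arr).long().to(device)

    pheno_label = torch.from_numpy(
        np.asarray(data['pheno label'])).float().to(device)
    # used when the last batch is smaller than the batch size
    num_valid_data = np.asarray(data['decomp label']).shape[0]
    return (to_dev('decomp label', flatten=True), to_dev('decomp mask'),
            to_dev('los label', flatten=True), to_dev('los mask'),
            to_dev('ihm label', flatten=True), to_dev('ihm mask'),
            pheno_label,
            to_dev('readmit label'), to_dev('readmit mask'),  # key names assumed
            to_dev('ltm label'), to_dev('ltm mask'),          # key names assumed
            num_valid_data)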
Example 3
def evaluate(epoch, data_loader, model, device, train_step=None, split="Test"):
    embedding_layer = nn.Embedding(vectors.shape[0], vectors.shape[1])
    embedding_layer.weight.data.copy_(torch.from_numpy(vectors))
    embedding_layer.weight.requires_grad = False
    '''
    if os.path.exists('./ckpt/multitasking_experiment_ts_text_union_cw.ckpt'):
        model.load_state_dict(torch.load('./ckpt/multitasking_experiment_ts_text_union_cw.ckpt'))
    '''

    epoch_metrics = utils.EpochWriter(split, regression)
    model.to(device)
    model.eval()
    total_data_points = 0

    running_loss = 0.0
    running_loss_decomp = 0.0
    running_loss_los = 0.0
    running_loss_ihm = 0.0
    running_loss_pheno = 0.0
    tk = tqdm(data_loader, total=int(len(data_loader)))
    criterion = nn.BCEWithLogitsLoss()

    for i, data in enumerate(tk):

        # if i>=50:
        #     break
        if data is None:
            continue
        num_valid_data = data['decomp label'].shape[0]
        total_data_points += num_valid_data
        name = data['name'][0]
        # print(name)

        ihm_mask = torch.from_numpy(np.array(data['ihm mask'])).long()
        ihm_mask = ihm_mask.to(device)

        ihm_label = np.array(data['ihm label'])
        ihm_label = ihm_label.reshape(-1, 1).squeeze(1)
        ihm_label = torch.from_numpy(ihm_label).long().to(device)

        decomp_mask = (torch.from_numpy(data['decomp mask'])).long().to(device)
        decomp_label = data['decomp label']
        #num_valid_data = decomp_label.shape[0]
        decomp_label = decomp_label.reshape(-1, 1).squeeze(1)  # (b*t,)
        decomp_label = torch.from_numpy(decomp_label).long().to(device)

        los_mask = torch.from_numpy(np.array(
            data['los mask'])).long().to(device)
        #print('los mask',los_mask)
        #np.save('./res/mask/los_mask_{}.npy'.format(name), los_mask)
        los_label = np.array(data['los label'])
        #np.save('./res/label/los_label_{}.npy'.format(name), los_label)
        los_label = los_label.reshape(-1, 1).squeeze(1)

        los_label = torch.from_numpy(los_label).to(device)
        pheno_label = torch.from_numpy(np.array(
            data['pheno label'])).float().to(device)

        if use_ts:
            ts = torch.from_numpy(data['time series'])
            ts = ts.permute(1, 0, 2).float().to(device)
        else:
            ts = None

        if use_wf:
            waveforms = torch.from_numpy(data['waveforms']).float()
            waveforms_weight_mat = torch.from_numpy(
                data['waveforms weight mat']).float()
            waveforms_weight_mat = waveforms_weight_mat.to(device)
            waveforms = waveforms.to(device)
        else:
            waveforms = None
            waveforms_weight_mat = None

        if use_text:
            texts = torch.from_numpy(data['texts']).to(torch.int64)
            texts_weight_mat = torch.from_numpy(
                data['texts weight mat']).float()
            texts_weight_mat = texts_weight_mat.to(device)

            texts = embedding_layer(
                texts)  # [batch_size, num_docs, seq_len, emb_dim]

            texts = texts.to(device)
            if text_model.name == 'avg':
                texts = avg_emb(texts, texts_weight_mat)
            # t = ts.shape[0]
            # b = ts.shape[1]
            # texts = torch.rand(b*t, 768).float().to(device)
            # texts_weight_mat = None
        else:
            texts = None
            texts_weight_mat = None

        decomp_logits, los_logits, ihm_logits, pheno_logits = model(
            ts=ts, texts=texts, texts_weight_mat=texts_weight_mat,
            waveforms=waveforms, waveforms_weight_mat=waveforms_weight_mat)

        loss_decomp = masked_weighted_cross_entropy_loss(
            None, decomp_logits, decomp_label, decomp_mask)
        loss_los = masked_weighted_cross_entropy_loss(None, los_logits,
                                                      los_label, los_mask)
        # loss_los = masked_mse_loss(los_logits,
        #                            los_label,
        #                            los_mask)
        loss_ihm = masked_weighted_cross_entropy_loss(None, ihm_logits,
                                                      ihm_label, ihm_mask)
        loss_pheno = criterion(pheno_logits, pheno_label)

        losses = {
            'decomp': loss_decomp,
            'ihm': loss_ihm,
            'los': loss_los,
            'pheno': loss_pheno
        }
        loss = 0.0

        for task in losses:
            loss += losses[task] * task_weight[task]

        running_loss += loss.item()
        running_loss_decomp += loss_decomp.item() * task_weight['decomp']
        running_loss_los += loss_los.item() * task_weight['los']
        running_loss_ihm += loss_ihm.item() * task_weight['ihm']
        running_loss_pheno += loss_pheno.item() * task_weight['pheno']

        m = nn.Softmax(dim=1)
        sigmoid = nn.Sigmoid()

        decomp_pred = (sigmoid(decomp_logits)[:, 1]).cpu().detach().numpy()
        #los_pred = torch.argmax(m(los_logits), dim=1).cpu().detach().numpy()
        los_pred = m(los_logits).cpu().detach().numpy()
        #np.save('./res/pred/los_pred_{}.npy'.format(name), los_pred)
        #los_pred = los_logits.cpu().detach().numpy()
        ihm_pred = (sigmoid(ihm_logits)[:, 1]).cpu().detach().numpy()
        pheno_pred = sigmoid(pheno_logits).cpu().detach().numpy()
        #print(pheno_pred.sum(dim=1))

        outputs = {
            'decomp': {
                'pred': decomp_pred,
                'label': decomp_label.cpu().detach().numpy(),
                'mask': decomp_mask.cpu().detach().numpy()
            },
            'ihm': {
                'pred': ihm_pred,
                'label': ihm_label.cpu().detach().numpy(),
                'mask': ihm_mask.cpu().detach().numpy()
            },
            'los': {
                'pred': los_pred,
                'label': los_label.cpu().detach().numpy(),
                'mask': los_mask.cpu().detach().numpy()
            },
            'pheno': {
                'pred': pheno_pred,
                'label': pheno_label.cpu().detach().numpy(),
                'mask': None
            },
        }

        epoch_metrics.cache(outputs, num_valid_data)

    #unique_elements, counts_elements = np.unique(metric_los_kappa.y_pred, return_counts=True)
    #labels, counts = np.unique(metric_los_kappa.y_true, return_counts=True)

    if train_step is not None:
        xpoint = train_step
    else:
        xpoint = epoch + 1

    epoch_metrics.write(writer, xpoint)
    # i holds the last batch index, so i + 1 batches were accumulated
    writer.add_scalar('{} loss'.format(split), running_loss / (i + 1), xpoint)
    writer.add_scalar('{} decomp loss'.format(split),
                      running_loss_decomp / (i + 1), xpoint)
    writer.add_scalar('{} los loss'.format(split), running_loss_los / (i + 1),
                      xpoint)
    writer.add_scalar('{} ihm loss'.format(split), running_loss_ihm / (i + 1),
                      xpoint)
    writer.add_scalar('{} pheno loss'.format(split),
                      running_loss_pheno / (i + 1), xpoint)
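
Example 3 inlines the text branch that Examples 1 and 2 delegate to a text_embedding helper. A minimal sketch of that helper, assuming it simply packages the inline steps shown above (pretrained-vector lookup plus the per-document weight matrix):

import torch

def text_embedding(embedding_layer, data, device):
    # same steps as the inline text branch in Example 3
    texts = torch.from_numpy(data['texts']).to(torch.int64)
    texts_weight_mat = torch.from_numpy(
        data['texts weight mat']).float().to(device)
    # [batch_size, num_docs, seq_len, emb_dim]
    texts = embedding_layer(texts).to(device)
    return texts, texts_weight_mat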
Example 4
def train(epochs, train_data_loader, test_data_loader, model, optimizer,
          device):
    # per-task class weights, derived from class frequencies
    decomp_weight = torch.FloatTensor([1.0214, 47.6688])
    decomp_weight = decomp_weight / decomp_weight.sum()
    ihm_weight = torch.FloatTensor([1.1565, 7.3888])
    ihm_weight = ihm_weight / ihm_weight.sum()
    # los_weight = torch.FloatTensor([ 66.9758,  30.3148,  13.7411,   6.8861,   4.8724,   4.8037,   5.7935,
    #       8.9295,  29.8249, 391.6768])
    los_weight = torch.FloatTensor([1.6047, 3.8934, 8.3376])
    #los_weight = los_weight/los_weight.sum()
    pheno_weight = torch.FloatTensor([
        19.2544, 55.1893, 40.1445, 12.8604, 30.7595, 31.4979, 19.9768, 57.2309,
        15.4088, 12.8200, 43.2644, 21.3991, 14.2026, 9.8531, 15.3284, 57.1641,
        31.0782, 46.4064, 81.0640, 102.7755, 47.5936, 29.6070, 22.7682,
        28.8175, 52.8856
    ])
    pheno_weight = pheno_weight / pheno_weight.sum()
    pheno_weight = pheno_weight.to(device)
    criterion = nn.BCEWithLogitsLoss()

    embedding_layer = nn.Embedding(vectors.shape[0], vectors.shape[1])
    embedding_layer.weight.data.copy_(torch.from_numpy(vectors))
    embedding_layer.weight.requires_grad = False

    model.to(device)
    train_b = 0

    running_loss_decomp = 0.0
    running_loss_los = 0.0
    running_loss_ihm = 0.0
    running_loss_pheno = 0.0
    running_loss = 0.0

    for epoch in range(epochs):
        total_data_points = 0
        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 50)
        model.train()

        running_loss = 0.0
        running_loss_decomp = 0.0
        running_loss_los = 0.0
        running_loss_ihm = 0.0
        running_loss_pheno = 0.0

        epoch_metrics = utils.EpochWriter("Train", regression)

        tk0 = tqdm(train_data_loader, total=int(len(train_data_loader)))
        '''
        if os.path.exists('./ckpt/multitasking_experiment_ts_text_union_cw.ckpt'):
            model.load_state_dict(torch.load('./ckpt/multitasking_experiment_ts_text_union_cw.ckpt'))
        '''

        for i, data in enumerate(tk0):
            if data is None:
                continue

            ihm_mask = torch.from_numpy(np.array(data['ihm mask']))
            ihm_mask = ihm_mask.to(device)
            ihm_label = torch.from_numpy(np.array(data['ihm label'])).long()
            ihm_label = ihm_label.reshape(-1, 1).squeeze(1)
            ihm_label = ihm_label.to(device)
            ihm_weight = ihm_weight.to(device)
            decomp_mask = torch.from_numpy(data['decomp mask'])
            decomp_mask = decomp_mask.to(device)
            decomp_label = torch.from_numpy(data['decomp label']).long()
            # the num valid data is used in case the last batch is smaller than batch size
            num_valid_data = decomp_label.shape[0]
            total_data_points += num_valid_data
            decomp_label = decomp_label.reshape(-1, 1).squeeze(1)  # (b*t,)
            decomp_label = decomp_label.to(device)
            decomp_weight = decomp_weight.to(device)

            los_mask = torch.from_numpy(np.array(data['los mask']))
            los_mask = los_mask.to(device)
            los_label = torch.from_numpy(np.array(data['los label']))

            los_label = los_label.reshape(-1, 1).squeeze(1)
            los_label = los_label.to(device)
            los_weight = los_weight.to(device)
            pheno_label = torch.from_numpy(np.array(
                data['pheno label'])).float()
            pheno_label = pheno_label.to(device)

            if use_ts:
                ts = torch.from_numpy(data['time series'])
                ts = ts.permute(1, 0, 2).float().to(device)
            else:
                ts = None
            if use_wf:
                waveforms = torch.from_numpy(data['waveforms']).float()
                waveforms_weight_mat = torch.from_numpy(
                    data['waveforms weight mat']).float()
                waveforms_weight_mat = waveforms_weight_mat.to(device)
                waveforms = waveforms.to(device)
            else:
                waveforms = None
                waveforms_weight_mat = None

            if use_text:
                texts = torch.from_numpy(data['texts']).to(torch.int64)
                texts_weight_mat = torch.from_numpy(
                    data['texts weight mat']).float()
                texts_weight_mat = texts_weight_mat.to(device)

                texts = embedding_layer(
                    texts)  # [batch_size, num_docs, seq_len, emb_dim]
                #print(texts.shape)
                texts = texts.to(device)
                #print(texts.shape)
                if text_model.name == 'avg':
                    texts = avg_emb(texts, texts_weight_mat)
                # t = ts.shape[0]
                # b = ts.shape[1]
                # texts = torch.rand(b*t, 768).float().to(device)
                # texts_weight_mat = None
            else:
                texts = None
                texts_weight_mat = None


            decomp_logits, los_logits, ihm_logits, pheno_logits = model(ts = ts, texts = texts,\
             texts_weight_mat = texts_weight_mat, waveforms = waveforms, waveforms_weight_mat =waveforms_weight_mat)
            loss_decomp = masked_weighted_cross_entropy_loss(
                None, decomp_logits, decomp_label, decomp_mask)
            #loss_los = masked_weighted_cross_entropy_loss(None, los_logits,los_label, los_mask)
            loss_los = masked_weighted_cross_entropy_loss(
                los_weight, los_logits, los_label, los_mask)
            #print(loss_los.item())
            loss_ihm = masked_weighted_cross_entropy_loss(
                None, ihm_logits, ihm_label, ihm_mask)
            loss_pheno = criterion(pheno_logits, pheno_label)

            losses = {
                'decomp': loss_decomp,
                'ihm': loss_ihm,
                'los': loss_los,
                'pheno': loss_pheno
            }
            loss = 0.0

            for task in losses:
                #loss += losses[task]

                prec = torch.exp(-log_var[task])
                #losses[task] = torch.sum(losses[task] * prec + log_var[task], -1)

                loss += torch.sum(losses[task] * prec + log_var[task], -1)
                #loss += losses[task] * task_weight[task]

                #loss += losses[task]
            #loss = torch.mean(loss)

            #loss = loss_decomp*5+loss_ihm*2+loss_los*1+loss_pheno*2
            #loss = loss_los
            train_b += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_loss_decomp += loss_decomp.item() * task_weight['decomp']
            running_loss_los += loss_los.item() * task_weight['los']
            running_loss_ihm += loss_ihm.item() * task_weight['ihm']
            running_loss_pheno += loss_pheno.item() * task_weight['pheno']

            m = nn.Softmax(dim=1)
            sig = nn.Sigmoid()

            decomp_pred = (m(decomp_logits)[:, 1]).cpu().detach().numpy()
            #los_pred = torch.argmax(m(los_logits), dim=1).cpu().detach().numpy()
            los_pred = m(los_logits).cpu().detach().numpy()
            #los_pred = los_logits.cpu().detach().numpy()
            ihm_pred = (m(ihm_logits)[:, 1]).cpu().detach().numpy()
            pheno_pred = sig(pheno_logits).cpu().detach().numpy()

            outputs = {
                'decomp': {
                    'pred': decomp_pred,
                    'label': decomp_label.cpu().detach().numpy(),
                    'mask': decomp_mask.cpu().detach().numpy()
                },
                'ihm': {
                    'pred': ihm_pred,
                    'label': ihm_label.cpu().detach().numpy(),
                    'mask': ihm_mask.cpu().detach().numpy()
                },
                'los': {
                    'pred': los_pred,
                    'label': los_label.cpu().detach().numpy(),
                    'mask': los_mask.cpu().detach().numpy()
                },
                'pheno': {
                    'pred': pheno_pred,
                    'label': pheno_label.cpu().detach().numpy(),
                    'mask': None
                },
            }
            epoch_metrics.cache(outputs, num_valid_data)
            interval = 500

            if i % interval == interval - 1:
                # `interval` batches were accumulated since the last reset
                writer.add_scalar('training loss',
                                  running_loss / interval, train_b)
                writer.add_scalar('decomp loss',
                                  running_loss_decomp / interval,
                                  train_b)
                writer.add_scalar('los loss',
                                  running_loss_los / interval, train_b)
                writer.add_scalar('ihm loss',
                                  running_loss_ihm / interval, train_b)
                writer.add_scalar('pheno loss',
                                  running_loss_pheno / interval, train_b)
                '''
                print('loss decomp', running_loss_decomp/interval)
                print('loss ihm', running_loss_ihm/interval)
                print('loss los', running_loss_los/interval)
                print('loss pheno', running_loss_pheno/interval)
                print('epoch {} , training loss is {:.3f}'.format(epoch+1, running_loss_los/interval))
                '''

                running_loss_decomp = 0.0
                running_loss_los = 0.0
                running_loss_ihm = 0.0
                running_loss_pheno = 0.0
                running_loss = 0.0
                epoch_metrics.add()
        for task in losses:
            print(task, torch.exp(-log_var[task]))
        epoch_metrics.write(writer, train_b)

        torch.save(
            model.state_dict(),
            os.path.join('./ckpt/', 'epoch{0}'.format(epoch) + model_weights))

        evaluate(epoch, val_data_loader, model, device, train_b, "Val")
        evaluate(epoch, test_data_loader, model, device, train_b, "Test")
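
Example 4 weights the task losses with learned homoscedastic uncertainty terms read from a global log_var dict, where exp(-log_var[task]) acts as a learned precision. Below is a self-contained sketch of how those terms could be set up and optimized jointly with the model; the toy model, task list, initialization, and learning rate are assumptions.

import torch
import torch.nn as nn

model = nn.Linear(8, 4)                      # toy stand-in for the real model
tasks = ['decomp', 'ihm', 'los', 'pheno']

# one learnable log-variance per task, initialized to 0 (precision = 1)
log_var = {t: torch.zeros(1, requires_grad=True) for t in tasks}

# the log-variances are trained together with the model parameters
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(log_var.values()), lr=1e-3)

# combining per-task losses the same way the loop above does
losses = {t: torch.rand(1) for t in tasks}   # placeholder per-task losses
loss = sum(torch.sum(losses[t] * torch.exp(-log_var[t]) + log_var[t], -1)
           for t in tasks)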
Example 5
def evaluate(epoch,
             data_loader,
             model,
             split,
             early_stopper,
             device,
             train_step=None):

    aucroc_readmit = utils.AUCROCREADMIT()
    aucpr_readmit = utils.AUCPRREADMIT()
    cfm_readmit = utils.ConfusionMatrixReadmit()

    if split == 'val':
        epoch_metrics = utils.EpochWriter("Val", regression, experiment)

    else:
        epoch_metrics = utils.EpochWriter("Test", regression, experiment)

    model.to(device)
    model.eval()

    running_loss = 0.0

    tk = tqdm(data_loader, total=int(len(data_loader)))
    criterion = nn.BCEWithLogitsLoss()

    for i, data in enumerate(tk):
        if data is None:
            continue


        decomp_label, decomp_mask, los_label, los_mask, ihm_label, ihm_mask,\
                 pheno_label, readmit_label, readmit_mask, num_valid_data = retrieve_data(data, device)

        if use_ts:
            ts = torch.from_numpy(data['time series'])
            ts = ts.permute(1, 0, 2).float().to(device)
        else:
            ts = None

        if use_text:
            texts = text_embedding(embedding_layer, data, device)
        else:
            texts = None

        readmit_logits = model(texts=texts)

        loss = masked_weighted_cross_entropy_loss(None, readmit_logits,
                                                  readmit_label, readmit_mask)

        running_loss += loss.item()

        sigmoid = nn.Sigmoid()
        readmit_pred = (sigmoid(readmit_logits)[:, 1]).cpu().detach().numpy()
        readmit_label = readmit_label.cpu().detach().numpy()
        readmit_mask = readmit_mask.cpu().detach().numpy()

        aucpr_readmit.add(readmit_pred, readmit_label, readmit_mask)
        aucroc_readmit.add(readmit_pred, readmit_label, readmit_mask)
        cfm_readmit.add(readmit_pred, readmit_label, readmit_mask)

    print('readmission aucpr is {}'.format(aucpr_readmit.get()))
    print('readmission aucroc is {}'.format(aucroc_readmit.get()))
    print('readmission cfm is {}'.format(cfm_readmit.get()))

    if train_step is not None:
        xpoint = train_step
    else:
        xpoint = epoch + 1

    writer.add_scalar('{} readmit loss'.format(split), running_loss / (i + 1),
                      xpoint)

    if split == 'val':
        early_stopper(running_loss / (i + 1), model)
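
The readmission metrics above are accumulators with an add(pred, label, mask) / get() interface. A hedged sketch of what the AUCPR one might look like using scikit-learn; only the interface is taken from the calls above, the body is an assumption:

import numpy as np
from sklearn.metrics import average_precision_score

class AUCPRREADMIT:
    def __init__(self):
        self.y_true, self.y_pred = [], []

    def add(self, pred, label, mask):
        # keep only the positions the mask marks as valid
        valid = np.asarray(mask).reshape(-1).astype(bool)
        self.y_pred.append(np.asarray(pred).reshape(-1)[valid])
        self.y_true.append(np.asarray(label).reshape(-1)[valid])

    def get(self):
        # area under the precision-recall curve over everything cached so far
        return average_precision_score(np.concatenate(self.y_true),
                                       np.concatenate(self.y_pred))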
Example 6
def train(epochs, train_data_loader, test_data_loader, early_stopper, model,
          optimizer, scheduler, device):

    criterion = nn.BCEWithLogitsLoss()
    aucroc_readmit = utils.AUCROCREADMIT()
    aucpr_readmit = utils.AUCPRREADMIT()
    cfm_readmit = utils.ConfusionMatrixReadmit()

    model.to(device)
    train_b = 0

    for epoch in range(epochs):

        print('Epoch {}/{}'.format(epoch + 1, epochs))
        print('-' * 50)
        model.train()

        running_loss = 0.0

        epoch_metrics = utils.EpochWriter("Train", regression, experiment)

        tk0 = tqdm(train_data_loader, total=int(len(train_data_loader)))

        for i, data in enumerate(tk0):
            if data is None:
                continue
            decomp_label, decomp_mask, los_label, los_mask, ihm_label, ihm_mask,\
                 pheno_label, readmit_label, readmit_mask, num_valid_data = retrieve_data(data, device)

            if use_text:
                texts = text_embedding(embedding_layer, data, device)
            else:
                texts = None

            readmit_logits = model(texts=texts)

            loss = masked_weighted_cross_entropy_loss(None, readmit_logits,
                                                      readmit_label,
                                                      readmit_mask)

            train_b += 1

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            sig = nn.Sigmoid()

            readmit_pred = (sig(readmit_logits)[:, 1]).cpu().detach().numpy()
            readmit_label = readmit_label.cpu().detach().numpy()
            readmit_mask = readmit_mask.cpu().detach().numpy()
            aucpr_readmit.add(readmit_pred, readmit_label, readmit_mask)
            aucroc_readmit.add(readmit_pred, readmit_label, readmit_mask)
            cfm_readmit.add(readmit_pred, readmit_label, readmit_mask)
            interval = 50

            if i % interval == interval - 1:
                # `interval` batches were accumulated; reset after logging
                writer.add_scalar('training loss',
                                  running_loss / interval, train_b)
                running_loss = 0.0

        print('readmission aucpr is {}'.format(aucpr_readmit.get()))
        print('readmission aucroc is {}'.format(aucroc_readmit.get()))
        print('readmission cfm is {}'.format(cfm_readmit.get()))

        #scheduler.step()
        #evaluate(epoch, val_data_loader, model, 'val', early_stopper, device, train_b)
        evaluate(epoch, test_data_loader, model, 'test', early_stopper, device,
                 train_b)
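
The early_stopper threaded through these examples is called as early_stopper(val_loss, model) and exposes an early_stop flag. A minimal sketch consistent with that interface, in the spirit of the common PyTorch early-stopping recipe; the patience value and checkpoint path are assumptions:

import numpy as np
import torch

class EarlyStopping:
    def __init__(self, patience=5, path='./ckpt/best.ckpt'):
        self.patience = patience
        self.path = path
        self.best_loss = np.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss:
            # validation improved: checkpoint the weights and reset patience
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True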