Example #1
def train_model(model, optimizer, loss_fn, dataloader, device, use_bert):
    """Train model
    Args:
        model: Model either LSTM, LSTMAttention, CNN, MLP (torch.nn.Module)
        optimizer: Optimizer for parameters of the model (torch.optim)
        loss_fn: Loss function taht computs the loss for each batch based on the y_pred and y_target
        dataloader: Dataloader that generates batches of data and labels or in case of BERT input_ids, input_mask, segment_ids and label_ids
        device: Device run either on GPU or CPU
    """

    # Metrics (per-class values accumulated as numpy arrays so "+=" adds element-wise)
    epoch_loss = 0
    epoch_accuracy = 0
    epoch_recall = np.zeros(2)
    epoch_precision = np.zeros(2)
    epoch_f1 = np.zeros(2)

    #Set model in training mode
    model.train()

    for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):

        #Step 0: Get batch
        batch = tuple(t.to(device) for t in batch)

        if use_bert:
            input_ids, input_mask, segment_ids, label_ids = batch
        else:
            X, y_target = batch
            y_target = y_target.long()

        #Step 1: Clear the gradients
        optimizer.zero_grad()

        #Step 2: Compute the forward pass of the model (model output)
        if use_bert:
            y_pred = model(input_ids, segment_ids, input_mask, labels=None)
            y_target = label_ids
        else:
            y_pred = model(X)

        #Step 3: Compute the loss
        loss = loss_fn(y_pred, y_target)
        loss_batch = loss.item()
        epoch_loss += loss_batch

        #Step 4: Propagate the loss backward
        loss.backward()

        #Step 5: Use optimizer to take gradient step
        optimizer.step()

        #Compute other metrics
        accuracy, recall, precision, f1 = accuracy_recall_precision_f1(
            y_pred, y_target)

        epoch_accuracy += accuracy
        epoch_recall += recall
        epoch_precision += precision
        epoch_f1 += f1

    #Train results
    results = {
        'loss': np.round(epoch_loss / len(dataloader), 2),
        'accuracy': np.round(float(epoch_accuracy / len(dataloader)), 2),
        'recall': np.round(epoch_recall / len(dataloader), 2),
        'precision': np.round(epoch_precision / len(dataloader), 2),
        'f1': np.round(epoch_f1 / len(dataloader), 2)
    }

    return results
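A minimal usage sketch for the function above, assuming a small feed-forward classifier on dummy data; the model, dataset and hyperparameters are illustrative and not from the original source:

# Usage sketch (assumptions: dummy data, a small nn.Sequential classifier; the
# metric helper accuracy_recall_precision_f1 is expected to exist in this module).
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 128 dummy samples with 300-dimensional features and 2 classes
X = torch.randn(128, 300)
y = torch.randint(0, 2, (128,))
dataloader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = nn.Sequential(nn.Linear(300, 64), nn.ReLU(), nn.Linear(64, 2)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

results = train_model(model, optimizer, loss_fn, dataloader, device, use_bert=False)
print(results)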
Example #2
    def validate(self, val_loader, epoch=0):
        self.model.eval()
        val_loss = 0.0
        total_acc = 0.0
        total_recall = 0.0
        total_precision = 0.0
        total_f1 = 0.0
        total_cm = 0
        total_d_acc = 0.0
        bleu = 0.0
        total_l1 = 0
        total_l2 = 0
        total_l3 = 0

        k_vals = [1, 2, 3, 4, 5]
        total_topk = {k: 0.0 for k in k_vals}
        per_disease_topk = defaultdict(lambda: {str(k): 0.0 for k in k_vals})
        per_disease_bleu = defaultdict(list)
        with torch.no_grad():
            for i, (_, images, labels, f_labels,
                    text) in enumerate(val_loader):
                batch_size = images.size(0)
                images = images.to(self.device)
                labels = labels.to(self.device)
                f_labels = f_labels.to(self.device)
                text = text.to(self.device)
                diseases, fine_diseases, text_pred = self.model(images, text)
                loss1 = self.criterion(diseases, labels)
                loss2 = self.criterion(fine_diseases, f_labels)
                text_loss = 0.0
                for k in range(text_pred.size(1)):
                    text_loss += self.criterion(text_pred[:, k].squeeze(),
                                                text[:, k + 1].squeeze())

                val_loss += torch.stack(
                    (loss1, loss2, text_loss))[self.tasks].sum()

                preds = F.log_softmax(fine_diseases, dim=-1)
                pred = preds.argmax(dim=-1)
                d_pred = F.log_softmax(diseases, dim=-1).argmax(dim=-1)

                # Evaluation of P, R, F1, CM, BLEU
                total_acc += (pred.eq(f_labels).sum().item() / batch_size)
                total_d_acc += (d_pred.eq(labels).sum().item() / batch_size)
                acc, recall, precision, f1 = accuracy_recall_precision_f1(
                    d_pred, labels)
                cm = calculate_confusion_matrix(d_pred, labels)
                try:
                    total_cm += (cm / batch_size)
                except Exception:
                    print("Error occurred for this CM")
                    print(cm / batch_size)

                # Top-k evaluation
                for k in k_vals:
                    total_topk[k] += compute_topk(preds, f_labels, k)
                    for d in [0, 1, 2, 3]:
                        mask = labels.eq(d)
                        if mask.sum() > 0:
                            per_disease_topk[d][str(k)] += compute_topk(
                                preds[mask], f_labels[mask], k)

                total_recall += np.mean(recall)
                total_precision += np.mean(precision)
                total_f1 += np.mean(f1)
                preds = torch.argmax(F.log_softmax(text_pred, dim=-1), dim=-1)
                text1 = text[:, 1:].squeeze().tolist()
                preds1 = preds.tolist()
                t_bleu, sent_gt, sent_pred = compute_bleu(
                    self.lang, text1, preds1, labels, per_disease_bleu)

                # Book-keeping
                bleu += t_bleu
                total_l1 += loss1.item()
                total_l2 += loss2.item()
                total_l3 += text_loss.item()
        bleu = bleu / (len(val_loader))
        val_loss = val_loss / len(val_loader)
        total_l1 /= len(val_loader)
        total_l2 /= len(val_loader)
        total_l3 /= len(val_loader)
        total_acc = total_acc / len(val_loader)
        total_d_acc = total_d_acc / len(val_loader)
        total_f1 = total_f1 / len(val_loader)
        total_precision = total_precision / len(val_loader)
        total_recall = total_recall / len(val_loader)
        total_cm = total_cm / len(val_loader)

        self.scheduler.step(val_loss)
        if val_loss <= self.min_val_loss:
            torch.save(self.model.state_dict(), self.save_path)
            self.min_val_loss = val_loss

        disease_f1 = {}
        disease_precision = {}
        disease_recall = {}

        #for i in range(len(total_f1)):
        #   disease_f1[i] = total_f1[i]
        #   disease_precision[i] = total_precision[i]
        #   disease_recall[i] = total_recall[i]
        for d in per_disease_bleu:
            per_disease_bleu[d] = np.mean(per_disease_bleu[d])

        total_topk = {str(k): total_topk[k] / len(val_loader) for k in k_vals}
        for d in [0, 1, 2, 3]:
            for k in k_vals:
                per_disease_topk[d][str(k)] /= len(val_loader)

        return (val_loss, total_d_acc, total_acc, bleu, total_f1, total_recall,
                total_precision, sent_gt, sent_pred, total_topk,
                per_disease_topk, per_disease_bleu, total_cm)
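The validation loop above calls a compute_topk helper that is not shown in the listing. Its behavior appears to be per-batch top-k accuracy over class scores; a hedged sketch consistent with how it is called (and later averaged over batches) could be:

# Hedged sketch of compute_topk; the real implementation is not in the listing,
# so the signature and semantics are inferred from the call sites above.
import torch

def compute_topk(scores, targets, k):
    """Fraction of samples whose true label is among the k highest-scoring classes.

    scores: (batch, num_classes) logits or log-probabilities
    targets: (batch,) integer class labels
    """
    topk_idx = scores.topk(k, dim=-1).indices              # (batch, k)
    hits = topk_idx.eq(targets.unsqueeze(-1)).any(dim=-1)  # (batch,)
    return hits.float().mean().item()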
Example #3
    def validate(self, val_loader, epoch=0):
        self.model.eval()
        val_loss = 0.0
        total_acc = 0.0
        total_recall = 0.0
        total_precision = 0.0
        total_f1 = 0.0
        total_cm = 0
        total_d_acc = 0.0
        bleu = 0.0
        total_l1 = 0
        total_l2 = 0
        total_l3 = 0

        k_vals = [1, 2, 3, 4, 5]
        total_topk = {k: 0.0 for k in k_vals}
        per_disease_topk = defaultdict(lambda: {str(k): 0.0 for k in k_vals})
        losses = []
        with torch.no_grad():
            for i, (images, labels) in enumerate(val_loader):
                batch_size = images.size(0)
                images = images.to(self.device)
                labels = labels.to(self.device)
                teacher_scores = self.kd_model(images)
                val, pred = F.softmax(teacher_scores, dim=-1).max(dim=-1)

                index = val >= self.threshold
                if index.any().item():
                    images = images[index]
                    labels = labels[index]
                    teacher_scores = teacher_scores[index]
                    diseases = self.model.module(images)

                    loss1, _, _ = self.criterion(diseases, labels,
                                                 teacher_scores)

                    val_loss += loss1.item()

                    # Evaluation of P, R, F1, BLEU
                    d_pred = F.log_softmax(diseases, dim=-1).argmax(dim=-1)
                    total_d_acc += (d_pred.eq(labels).sum().item() /
                                    batch_size)
                    acc, recall, precision, f1 = accuracy_recall_precision_f1(
                        d_pred, labels)

                    total_recall += np.mean(recall)
                    total_precision += np.mean(precision)
                    total_f1 += np.mean(f1)

                    cm = calculate_confusion_matrix(d_pred, labels)
                    try:
                        total_cm += (cm / batch_size)
                    except Exception:
                        print("Error occurred for this CM")
                        print(cm / batch_size)
        val_loss = val_loss / len(val_loader)
        total_d_acc = total_d_acc / len(val_loader)
        total_f1 = total_f1 / len(val_loader)
        total_precision = total_precision / len(val_loader)
        total_recall = total_recall / len(val_loader)
        total_cm = total_cm / len(val_loader)

        self.scheduler.step(val_loss)
        if val_loss <= self.min_val_loss:
            torch.save(self.model.state_dict(), self.save_path)
            self.min_val_loss = val_loss

        disease_f1 = {}
        disease_precision = {}
        disease_recall = {}

        # for i in range(len(total_f1)):
        #   disease_f1[i] = total_f1[i]
        #   disease_precision[i] = total_precision[i]
        #   disease_recall[i] = total_recall[i]

        return (val_loss, total_d_acc, total_f1, total_recall, total_precision,
                total_cm)
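In this knowledge-distillation variant, self.criterion is called with the student logits, the hard labels and the teacher scores, and returns a three-element tuple; its implementation is not part of the listing. A conventional distillation loss with that call shape is sketched below (the class name, temperature and weighting are assumptions, not values from the original code):

# Hedged sketch of a distillation criterion matching the call
# criterion(student_logits, labels, teacher_logits) -> (loss, ce, kd).
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, temperature=4.0, alpha=0.5):
        super().__init__()
        self.T = temperature
        self.alpha = alpha
        self.ce = nn.CrossEntropyLoss()

    def forward(self, student_logits, labels, teacher_logits):
        ce_loss = self.ce(student_logits, labels)
        # KL divergence between softened student and teacher distributions
        kd_loss = F.kl_div(
            F.log_softmax(student_logits / self.T, dim=-1),
            F.softmax(teacher_logits / self.T, dim=-1),
            reduction="batchmean",
        ) * (self.T ** 2)
        loss = (1 - self.alpha) * ce_loss + self.alpha * kd_loss
        return loss, ce_loss, kd_loss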
def evaluate_model_features(model, optimizer, loss_fn, dataloader, device, use_bert):
    """Evaluate model
    Args:
        model: Model either LSTM, LSTMAttention, CNN, MLP (torch.nn.Module)
        optimizer: Optimizer for parameters of the model (torch.optim)
        loss_fn: Loss function taht computs the loss for each batch based on the y_pred and y_target
        dataloader: Dataloader that generates batches of data and labels or in case of BERT input_ids, input_mask, segment_ids and label_ids
        device: Device run either on GPU or CPU
    """

    # Metrics (per-class values accumulated as numpy arrays so "+=" adds element-wise)
    epoch_loss = 0
    epoch_accuracy = 0
    epoch_recall = np.zeros(2)
    epoch_precision = np.zeros(2)
    epoch_f1 = np.zeros(2)
    cm = np.zeros((2, 2))

    # Set model in evaluation mode
    model.eval()

    with torch.no_grad():

        for step, batch in enumerate(tqdm(dataloader, desc="Iteration")):

            #Step 0: Get batch
            batch = tuple(t.to(device) for t in batch)

            if use_bert:
                input_ids, input_mask, segment_ids, label_ids = batch
            else:
                X, features, y_target = batch
                y_target = y_target.long()

            #Step 1: Compute the forward pass of the model (model output)
            if use_bert:
                y_pred = model(input_ids, segment_ids, input_mask, labels=None)
                y_target = label_ids
            else:
                y_pred = model(X, features)

            #Step 2: Compute the loss
            loss = loss_fn(y_pred, y_target)
            loss_batch = loss.item()
            epoch_loss += loss_batch

            #Compute other metrics
            accuracy, recall, precision, f1 = accuracy_recall_precision_f1(y_pred, y_target)

            epoch_accuracy += accuracy
            epoch_recall += recall
            epoch_precision += precision
            epoch_f1 += f1

            # Compute confusion matrix
            cm += calculate_confusion_matrix(y_pred, y_target)

        #Evaluation results
        results = {
            'loss': np.round(epoch_loss / len(dataloader), 2),
            'accuracy': np.round(float(epoch_accuracy / len(dataloader)), 2),
            'recall': np.round(epoch_recall / len(dataloader), 2),
            'precision': np.round(epoch_precision / len(dataloader), 2),
            'f1': np.round(epoch_f1 / len(dataloader), 2),
            'cm': cm
        }

    return results
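All three examples rely on accuracy_recall_precision_f1 and calculate_confusion_matrix helpers that are not included here. A plausible reconstruction on top of scikit-learn, with signatures inferred from the call sites (logits or predicted class indices in, per-class numpy arrays out), is sketched below as an assumption rather than the original implementation:

# Hedged sketch of the metric helpers used above; inferred from the call sites,
# not taken from the original source.
import numpy as np
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support

def _to_labels(y_pred):
    """Accept either logits (batch, num_classes) or class indices (batch,)."""
    preds = y_pred.detach().cpu()
    if preds.dim() > 1:
        preds = preds.argmax(dim=-1)
    return preds.numpy()

def accuracy_recall_precision_f1(y_pred, y_target):
    preds = _to_labels(y_pred)
    targets = y_target.detach().cpu().numpy()
    accuracy = float((preds == targets).mean())
    precision, recall, f1, _ = precision_recall_fscore_support(
        targets, preds, average=None, zero_division=0)
    return accuracy, recall, precision, f1

def calculate_confusion_matrix(y_pred, y_target, labels=None):
    """Confusion matrix as a numpy array; pass labels=range(num_classes) to fix its shape."""
    return confusion_matrix(y_target.detach().cpu().numpy(), _to_labels(y_pred),
                            labels=labels)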