Example #1
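One training epoch for a transformers-style sequence classifier: the wrapped self.model is called with labels and returns loss and logits directly. All five examples assume that sys, tqdm and torch are imported, and they rely on a calculate_acc helper and an EpochResult container that are not shown; plausible sketches of both follow Examples #1 and #2.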
    def train_epoch(self, train_dataloader, optimizer, device):
        total_train_accuracy = 0
        total_train_loss = 0
        tp_tot, fp_tot, tn_tot, fn_tot = 0, 0, 0, 0
        self.model.train()
        pbar_file = sys.stdout
        pbar_name = "train_batch"
        num_batches = len(train_dataloader.batch_sampler)
        with tqdm.tqdm(desc=pbar_name, total=num_batches,
                       file=pbar_file) as pbar:
            for step, batch in enumerate(train_dataloader):
                b_input_ids = batch[0].to(device)
                b_input_mask = batch[1].to(device)
                b_labels = batch[2].to(device)
                self.model.zero_grad()

                # Forward pass
                output = self.model(b_input_ids,
                                    token_type_ids=None,
                                    attention_mask=b_input_mask,
                                    labels=b_labels)
                # Log the train loss
                loss = output.loss
                logits = output.logits
                total_train_loss += loss.item()

                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)

                # Weight updates
                optimizer.step()

                logits = logits.detach().cpu()
                y = b_labels.to('cpu').squeeze(1)  # (N, 1) -> (N,)
                y_pred = torch.argmax(logits, dim=1)
                tp, fp, tn, fn = calculate_acc(y_pred, y)
                tp_tot += tp
                fp_tot += fp
                tn_tot += tn
                fn_tot += fn
                total_train_accuracy += torch.sum(y_pred == y).float().item()
                pbar.set_description(f'{pbar_name} ({loss.item():.3f})')
                pbar.update()
        avg_train_accuracy = (total_train_accuracy /
                              len(train_dataloader.dataset)) * 100
        #print("  Training accuracy: {0:.2f}".format(avg_train_accuracy))
        print(
            f"accuracy={avg_train_accuracy:.3f}, tp: {tp_tot}, fp: {fp_tot}, tn: {tn_tot}, fn: {fn_tot}"
        )
        # Guard both denominators; the original check covered only tp + fn,
        # so an all-positive batch set could divide by zero below.
        if tp_tot + fn_tot > 0 and tn_tot + fp_tot > 0:
            print(f"Pos acc: {tp_tot / (tp_tot + fn_tot):.3f},  "
                  f"Neg acc: {tn_tot / (tn_tot + fp_tot):.3f}")
        avg_train_loss = total_train_loss / len(train_dataloader)
        # Log the Avg. train loss
        print("  Training loss: {0:.4f}".format(avg_train_loss))
        return EpochResult(avg_train_loss, avg_train_accuracy)
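calculate_acc is not defined in any of these examples; from its call sites it returns per-batch confusion counts for a binary classifier. A minimal sketch of a compatible implementation (the name and return order come from the calls above; the body itself is an assumption):

def calculate_acc(y_pred, y):
    """Assumed helper: confusion counts for binary labels.

    Returns (tp, fp, tn, fn) as Python ints, in the order the
    examples above unpack them.
    """
    tp = int(((y_pred == 1) & (y == 1)).sum())
    fp = int(((y_pred == 1) & (y == 0)).sum())
    tn = int(((y_pred == 0) & (y == 0)).sum())
    fn = int(((y_pred == 0) & (y == 1)).sum())
    return tp, fp, tn, fn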
Example #2
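One training epoch for a model that runs its own forward pass (e.g. a recurrent classifier) against an external loss_fn; note the guard that skips single-sample batches.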
    def train_epoch(self, train_dataloader, optimizer, loss_fn, device):
        total_train_accuracy = 0
        total_train_loss = 0
        tp_tot, fp_tot, tn_tot, fn_tot = 0, 0, 0, 0
        pbar_file = sys.stdout
        pbar_name = "train_batch"
        num_batches = len(train_dataloader.batch_sampler)
        with tqdm.tqdm(desc=pbar_name, total=num_batches,
                       file=pbar_file) as pbar:
            for step, batch in enumerate(train_dataloader):
                X, y = batch[0], batch[1]
                # Skip single-sample batches: squeeze() below would collapse
                # the batch dimension and break the loss/accuracy shapes.
                if y.shape[0] == 1:
                    continue

                # Forward pass
                X = X.to(device)
                y = y.to(device)
                y_pred_log_proba = self.forward(X)
                y = torch.squeeze(y).long()  # shape (N,)

                # Backward pass
                optimizer.zero_grad()
                loss = loss_fn(y_pred_log_proba, y)
                loss.backward()

                # Weight updates
                optimizer.step()

                # Calculate accuracy
                total_train_loss += loss.item()
                y_pred = torch.argmax(y_pred_log_proba, dim=1)
                tp, fp, tn, fn = calculate_acc(y_pred, y)
                tp_tot += tp
                fp_tot += fp
                tn_tot += tn
                fn_tot += fn
                total_train_accuracy += torch.sum(y_pred == y).float().item()
                pbar.set_description(f'{pbar_name} ({loss.item():.3f})')
                pbar.update()
        avg_train_accuracy = (total_train_accuracy /
                              len(train_dataloader.dataset)) * 100
        # print("  Training accuracy: {0:.2f}".format(avg_train_accuracy))
        print(
            f"  accuracy={avg_train_accuracy:.3f}, tp: {tp_tot}, fp: {fp_tot}, tn: {tn_tot}, fn: {fn_tot}"
        )
        avg_train_loss = total_train_loss / len(train_dataloader)
        # Log the Avg. train loss
        print("  Training loss: {0:.4f}".format(avg_train_loss))
        return EpochResult(avg_train_loss, avg_train_accuracy)
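EpochResult (and BatchResult, used in Example #5) are also assumed rather than shown. Examples #1-#4 construct EpochResult positionally with an average loss, while Example #5 passes losses= (a list) and accuracy= by keyword, so namedtuples along these lines satisfy both call patterns; the field names are inferred, not confirmed:

from collections import namedtuple

# Field names inferred from Example #5's keyword arguments.
EpochResult = namedtuple('EpochResult', ['losses', 'accuracy'])
# Field names inferred from batch_res.loss / batch_res.num_correct in Example #5.
BatchResult = namedtuple('BatchResult', ['loss', 'num_correct'])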
Example #3
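The evaluation counterpart of Example #1: the same confusion-count bookkeeping, but with the model in eval() mode and no gradient tracking.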
    def test_epoch(self, test_dataloader, device):
        self.model.eval()
        total_eval_accuracy = 0
        total_eval_loss = 0
        tp_tot, fp_tot, tn_tot, fn_tot = 0, 0, 0, 0
        # Evaluate data for one epoch
        pbar_file = sys.stdout
        pbar_name = "test_batch"
        num_batches = len(test_dataloader.batch_sampler)
        with tqdm.tqdm(desc=pbar_name, total=num_batches,
                       file=pbar_file) as pbar:
            for batch in test_dataloader:
                b_input_ids = batch[0].to(device)
                b_input_mask = batch[1].to(device)
                b_labels = batch[2].to(device)
                with torch.no_grad():
                    output = self.model(b_input_ids,
                                        token_type_ids=None,
                                        attention_mask=b_input_mask,
                                        labels=b_labels)
                    loss = output.loss
                    logits = output.logits
                    total_eval_loss += loss.item()
                    logits = logits.detach().cpu()
                    y = b_labels.to('cpu').squeeze(1)  # (N, 1) -> (N,)
                    y_pred = torch.argmax(logits, dim=1)
                    tp, fp, tn, fn = calculate_acc(y_pred, y)
                    tp_tot += tp
                    fp_tot += fp
                    tn_tot += tn
                    fn_tot += fn
                total_eval_accuracy += torch.sum(y_pred == y).float().item()
                pbar.set_description(f'{pbar_name} ({loss.item():.3f})')
                pbar.update()

        avg_val_accuracy = (total_eval_accuracy /
                            len(test_dataloader.dataset)) * 100
        # print("  Validation accuracy: {0:.2f}".format(avg_val_accuracy))
        print(
            f"accuracy={avg_val_accuracy:.3f}, tp: {tp_tot}, fp: {fp_tot}, tn: {tn_tot}, fn: {fn_tot}"
        )
        # Guard both denominators, as in Example #1.
        if tp_tot + fn_tot > 0 and tn_tot + fp_tot > 0:
            print(f"Pos acc: {tp_tot / (tp_tot + fn_tot):.3f},  "
                  f"Neg acc: {tn_tot / (tn_tot + fp_tot):.3f}")
        avg_val_loss = total_eval_loss / len(test_dataloader)
        # Log the Avg. validation accuracy
        print("  Validation Loss: {0:.4f}".format(avg_val_loss))
        return EpochResult(avg_val_loss, avg_val_accuracy)
Example #4
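The evaluation counterpart of Example #2: forward pass and metrics only, no optimizer step.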
    def test_epoch(self, test_dataloader, loss_fn, device):
        total_eval_accuracy = 0
        total_eval_loss = 0
        tp_tot, fp_tot, tn_tot, fn_tot = 0, 0, 0, 0
        # Evaluate data for one epoch
        pbar_file = sys.stdout
        pbar_name = "test_batch"
        num_batches = len(test_dataloader.batch_sampler)
        with tqdm.tqdm(desc=pbar_name, total=num_batches,
                       file=pbar_file) as pbar:
            for batch in test_dataloader:
                X, y = batch[0], batch[1]
                # Skip single-sample batches: squeeze() below would collapse
                # the batch dimension and break the loss shapes.
                if y.shape[0] == 1:
                    continue

                # Forward pass
                with torch.no_grad():
                    X = X.to(device)
                    y = y.to(device)
                    y_pred_log_proba = self.forward(X)

                    y = torch.squeeze(y).long()
                    loss = loss_fn(y_pred_log_proba, y)
                    total_eval_loss += loss.item()
                    y_pred = torch.argmax(y_pred_log_proba, dim=1)
                    tp, fp, tn, fn = calculate_acc(y_pred, y)
                    tp_tot += tp
                    fp_tot += fp
                    tn_tot += tn
                    fn_tot += fn

                total_eval_accuracy += torch.sum(y_pred == y).float().item()
                pbar.set_description(f'{pbar_name} ({loss.item():.3f})')
                pbar.update()

        avg_val_accuracy = (total_eval_accuracy /
                            len(test_dataloader.dataset)) * 100
        print(
            f"  accuracy={avg_val_accuracy:.3f}, tp: {tp_tot}, fp: {fp_tot}, tn: {tn_tot}, fn: {fn_tot}"
        )
        avg_val_loss = total_eval_loss / len(test_dataloader)
        # Log the Avg. validation accuracy
        print("  Validation Loss: {0:.4f}".format(avg_val_loss))
        return EpochResult(avg_val_loss, avg_val_accuracy)
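For context, a driver that ties a train_epoch/test_epoch pair together could look like the following. This is a hypothetical sketch (fit, num_epochs and the result lists are illustrative, not part of the source), using the loss_fn-style signatures of Examples #2 and #4:

    def fit(self, train_dataloader, test_dataloader, optimizer, loss_fn,
            device, num_epochs):
        # Hypothetical driver: one training pass and one evaluation
        # pass per epoch, collecting the EpochResult of each.
        train_results, test_results = [], []
        for epoch in range(num_epochs):
            print(f'--- Epoch {epoch + 1}/{num_epochs} ---')
            train_results.append(
                self.train_epoch(train_dataloader, optimizer, loss_fn, device))
            test_results.append(
                self.test_epoch(test_dataloader, loss_fn, device))
        return train_results, test_results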
Example #5
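A generic helper that drives an arbitrary per-batch forward function over a DataLoader and aggregates its BatchResults into an EpochResult. Beyond the shared imports, it assumes os, typing's Any and Callable, and DataLoader from torch.utils.data.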
    @staticmethod
    def _foreach_batch(dl: DataLoader,
                       forward_fn: Callable[[Any], BatchResult],
                       verbose=True,
                       max_batches=None) -> EpochResult:
        """
        Evaluates the given forward-function on batches from the given
        dataloader, and prints progress along the way.
        """
        losses = []
        num_correct = 0
        num_samples = len(dl.sampler)
        num_batches = len(dl.batch_sampler)

        if max_batches is not None and max_batches < num_batches:
            num_batches = max_batches
            num_samples = num_batches * dl.batch_size

        if verbose:
            pbar_file = sys.stdout
        else:
            pbar_file = open(os.devnull, 'w')

        pbar_name = forward_fn.__name__
        with tqdm.tqdm(desc=pbar_name, total=num_batches,
                       file=pbar_file) as pbar:
            dl_iter = iter(dl)
            for batch_idx in range(num_batches):
                data = next(dl_iter)
                batch_res = forward_fn(data)

                pbar.set_description(f'{pbar_name} ({batch_res.loss:.3f})')
                pbar.update()

                losses.append(batch_res.loss)
                num_correct += batch_res.num_correct

            avg_loss = sum(losses) / num_batches
            accuracy = 100. * num_correct / num_samples
            pbar.set_description(f'{pbar_name} '
                                 f'(Avg. Loss {avg_loss:.3f}, '
                                 f'Accuracy {accuracy:.1f})')

        return EpochResult(losses=losses, accuracy=accuracy)
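To drive _foreach_batch, the forward function only has to return something with loss and num_correct attributes. A hypothetical test_batch (the device, loss_fn and forward attributes here are illustrative assumptions) and its call site:

    def test_batch(self, batch) -> BatchResult:
        # Hypothetical per-batch evaluation function for _foreach_batch.
        X, y = batch
        X, y = X.to(self.device), y.to(self.device)
        with torch.no_grad():
            y_pred_log_proba = self.forward(X)
            loss = self.loss_fn(y_pred_log_proba, y)
            num_correct = int(
                (torch.argmax(y_pred_log_proba, dim=1) == y).sum())
        return BatchResult(loss.item(), num_correct)

    # e.g., inside the same class:
    #     result = self._foreach_batch(test_dl, self.test_batch, verbose=True)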