def multi_model_valid(ensemble_model, model_lst, dataloader, criterion,
                      device):
    ensemble_model.eval()
    epoch_start = time.time()
    running_loss = 0.0
    running_accuracy = 0.0
    with torch.no_grad():
        for batch in dataloader:
            premises = batch["premise"].to(device)
            premises_lengths = batch["premise_length"].to(device)
            hypotheses = batch["hypothesis"].to(device)
            hypotheses_lengths = batch["hypothesis_length"].to(device)
            labels = batch["label"].to(device)
            logits_probs_lst = [
                model(premises, premises_lengths, hypotheses,
                      hypotheses_lengths) for model in model_lst
            ]
            logits_lst = [i[0].unsqueeze(1) for i in logits_probs_lst]
            probs_lst = [i[1].unsqueeze(1) for i in logits_probs_lst]
            logits, probs = ensemble_model(logits_lst, probs_lst)
            loss = criterion(logits, labels)
            running_loss += loss.item()
            running_accuracy += correct_predictions(probs, labels)
    epoch_time = time.time() - epoch_start
    epoch_loss = running_loss / len(dataloader)
    epoch_accuracy = running_accuracy / (len(dataloader.dataset))
    return epoch_time, epoch_loss, epoch_accuracy
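# NOTE: the accuracy bookkeeping above relies on a `correct_predictions`
# helper that is not included in this listing. A minimal sketch of what it is
# assumed to do (argmax over the class probabilities, compared against the
# gold labels); the actual helper may differ.
def correct_predictions(probs, labels):
    """Return the number of correctly classified samples in a batch."""
    _, out_classes = probs.max(dim=1)
    return (out_classes == labels).sum().item()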
def multi_model_train(args,
                      epoch,
                      ensemble_model,
                      model_lst,
                      dataloader,
                      optimizer,
                      criterion,
                      max_gradient_norm,
                      device):
    ensemble_model.train()
    epoch_start = time.time()
    batch_time_avg = 0.0
    running_loss = 0.0
    correct_preds = 0
    # In distributed runs the sampler epoch must be set before iteration
    # starts, otherwise every epoch reuses the same shuffling.
    if args.local_rank != -1:
        dataloader.sampler.set_epoch(epoch)
    tqdm_batch_iterator = tqdm(dataloader)
    num_labels = 0
    for batch_index, batch in enumerate(tqdm_batch_iterator):
        batch_start = time.time()
        premises = batch["premise"].to(device)
        premises_lengths = batch["premise_length"].to(device)
        hypotheses = batch["hypothesis"].to(device)
        hypotheses_lengths = batch["hypothesis_length"].to(device)
        labels = batch["label"].to(device)
        num_labels += len(labels)
        optimizer.zero_grad()
        logits_probs_lst = [model(premises,
                                  premises_lengths,
                                  hypotheses,
                                  hypotheses_lengths) for model in model_lst]
        logits_lst = [i[0].unsqueeze(1) for i in logits_probs_lst]
        probs_lst = [i[1].unsqueeze(1) for i in logits_probs_lst]
        logits, probs = ensemble_model(logits_lst, probs_lst)
        loss = criterion(logits, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(ensemble_model.parameters(), max_gradient_norm)
        optimizer.step()
        batch_time_avg += time.time() - batch_start
        running_loss += loss.item()
        correct_preds += correct_predictions(probs, labels)
        description = "Avg. batch proc. time: {:.4f}s, loss: {:.4f}"\
                      .format(batch_time_avg/(batch_index+1),
                              running_loss/(batch_index+1))
        tqdm_batch_iterator.set_description(description)
    epoch_time = time.time() - epoch_start
    epoch_loss = running_loss / len(dataloader)
    epoch_accuracy = correct_preds / num_labels
    return epoch_time, epoch_loss, epoch_accuracy
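# NOTE: `Ensemble_model1`/`Ensemble_model2` are not part of this listing. The
# class below is an assumed weighted-averaging sketch that is consistent with
# how they are called here (one (batch, 1, num_classes) tensor per base model,
# concatenated along dim=1) and with the `weight_layer.weight` access in
# `main`; the real implementations may differ.
import torch
import torch.nn as nn


class WeightedEnsembleSketch(nn.Module):
    def __init__(self, model_n):
        super().__init__()
        # One learnable weight per base model.
        self.weight_layer = nn.Linear(model_n, 1, bias=False)

    def forward(self, logits_lst, probs_lst):
        # Shape after concatenation: (batch, model_n, num_classes).
        logits = torch.cat(logits_lst, dim=1)
        probs = torch.cat(probs_lst, dim=1)
        weights = self.weight_layer.weight.view(1, -1, 1)
        # Weighted sum over the model dimension.
        return (weights * logits).sum(dim=1), (weights * probs).sum(dim=1)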
def test(model, dataloader):
    """
    Test the accuracy of a model on some labelled test dataset.
    Args:
        model: The torch module on which testing must be performed.
        dataloader: A DataLoader object to iterate over some dataset.
    Returns:
        batch_time: The average time to predict the classes of a batch.
        total_time: The total time to process the whole dataset.
        accuracy: The accuracy of the model on the input data.
    """
    # Switch the model to eval mode.
    model.eval()
    device = model.device

    time_start = time.time()
    batch_time = 0.0
    accuracy = 0.0

    # Deactivate autograd for evaluation.
    with torch.no_grad():
        for batch in dataloader:
            batch_start = time.time()

            # Move input and output data to the GPU if one is used.
            q1 = batch["q1"].to(device)
            q1_lengths = batch["q1_length"].to(device)
            q2 = batch["q2"].to(device)
            q2_lengths = batch["q2_length"].to(device)
            labels = batch["label"].to(device)

            _, probs = model(q1, q1_lengths, q2, q2_lengths)

            accuracy += correct_predictions(probs, labels)
            batch_time += time.time() - batch_start

    batch_time /= len(dataloader)
    total_time = time.time() - time_start
    accuracy /= (len(dataloader.dataset))

    return batch_time, total_time, accuracy
def validate(model, dataloader, criterion, device):
    model.eval()
    epoch_start = time.time()
    running_loss = 0.0
    running_accuracy = 0.0
    with torch.no_grad():
        for batch in dataloader:
            premises = batch["premise"].to(device)
            premises_lengths = batch["premise_length"].to(device)
            hypotheses = batch["hypothesis"].to(device)
            hypotheses_lengths = batch["hypothesis_length"].to(device)
            labels = batch["label"].to(device)
            logits, probs = model(premises, premises_lengths, hypotheses,
                                  hypotheses_lengths)
            loss = criterion(logits, labels)
            running_loss += loss.item()
            running_accuracy += correct_predictions(probs, labels)
    epoch_time = time.time() - epoch_start
    epoch_loss = running_loss / len(dataloader)
    epoch_accuracy = running_accuracy / (len(dataloader.dataset))
    return epoch_time, epoch_loss, epoch_accuracy
def output_predictions(model,
                       dataset='test',
                       filename='svhn_classification_{dataset}.txt'):
    correct = 0
    total = 0
    with open(filename.format(dataset=dataset), "w",
              encoding="utf-8") as out_file:
        # TODO: Predict the digits and their bounding boxes on the test set.
        dataset = getattr(SVHN(), dataset).map(SVHN.parse)
        mapped_dataset = dataset.map(scale_input(model.args.image_size))
        for x, xorig in zip(mapped_dataset.batch(1), dataset):
            predicted_bboxes, scores, predicted_classes, valid = \
                model.predict_on_batch(x)
            num_valid = valid[0].numpy()
            predicted_bboxes = predicted_bboxes[0, :num_valid, ...].numpy()
            predicted_classes = predicted_classes[0, :num_valid, ...].numpy()
            scores = scores[0, :num_valid, ...].numpy()

            keep = scores > model.args.score_threshold
            predicted_bboxes = predicted_bboxes[keep]
            predicted_classes = predicted_classes[keep]

            transformed_bboxes = []
            output = []
            for label, bbox in zip(predicted_classes, predicted_bboxes):
                output.append(label.astype(np.int32))
                bbox_transformed = tf.cast(
                    tf.shape(xorig['image'])[1],
                    tf.float32).numpy() * bbox / float(model.args.image_size)
                transformed_bboxes.append(bbox_transformed.astype(np.int32))
                output.extend(bbox_transformed.astype(np.int32))
            print(*output, file=out_file)

            correct += utils.correct_predictions(
                xorig["classes"].numpy(), xorig["bboxes"].numpy(),
                predicted_classes.astype(np.int32), transformed_bboxes)
            total += 1
    return correct / total
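# NOTE: `utils.correct_predictions` for the detection task is not shown in
# this listing. It is assumed to score a prediction as correct only when the
# class matches and the predicted box overlaps the gold box sufficiently.
# A minimal IoU sketch of that box-level check (the threshold and the box
# layout [top, left, bottom, right] are assumptions):
def bbox_iou_sketch(a, b):
    """Intersection over union of two boxes given as [top, left, bottom, right]."""
    top, left = max(a[0], b[0]), max(a[1], b[1])
    bottom, right = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, bottom - top) * max(0.0, right - left)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter + 1e-9)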
def test(args, model, dataloader, device):
    model.eval()

    time_start = time.time()
    batch_time = 0.0
    accuracy = 0.0
    first = True

    batch_i = 0
    n_high_overlap = 0
    n_reg_overlap = 0
    n_low_overlap = 0
    n_long_sentence = 0
    n_reg_sentence = 0
    n_short_sentence = 0
    n_negation = 0
    n_quantifier = 0
    n_belief = 0
    n_total_high_overlap = 0
    n_total_reg_overlap = 0
    n_total_low_overlap = 0
    n_total_long_sentence = 0
    n_total_reg_sentence = 0
    n_total_short_sentence = 0
    n_total_negation = 0
    n_total_quantifier = 0
    n_total_belief = 0
    n_total_entailment = 0
    n_total_neutral = 0
    n_total_contradiction = 0
    n_correct_entailment = 0
    n_correct_neutral = 0
    n_correct_contradiction = 0

    stat_file = args.test_statistics
    with open(stat_file, "rb") as pkl:
        test_statistics = pickle.load(pkl)

    set_idx_high_overlap = set(test_statistics["high_overlap"])
    set_idx_reg_overlap = set(test_statistics["reg_overlap"])
    set_idx_low_overlap = set(test_statistics["low_overlap"])
    set_idx_long_sentence = set(test_statistics["long_sentence"])
    set_idx_reg_sentence = set(test_statistics["reg_sentence"])
    set_idx_short_sentence = set(test_statistics["short_sentence"])
    set_idx_negation = set(test_statistics["negation"])
    set_idx_quantifier = set(test_statistics["quantifier"])
    set_idx_belief = set(test_statistics["belief"])

    # Deactivate autograd for evaluation.
    with torch.no_grad():
        for batch in dataloader:
            batch_start = time.time()

            # Move input and output data to the GPU if one is used.
            premises = batch["premise"].to(device)
            premises_lengths = batch["premise_length"].to(device)
            hypotheses = batch["hypothesis"].to(device)
            hypotheses_lengths = batch["hypothesis_length"].to(device)
            labels = batch["label"].to(device)

            _, probs = model(premises, premises_lengths, hypotheses,
                             hypotheses_lengths)

            accuracy += correct_predictions(probs, labels)
            batch_time += time.time() - batch_start
            _, out_classes = probs.max(dim=1)
            # if first:
            #     print ('Predictions for the first 5 sentences:')
            #     print ('Predictions:')
            #     print (out_classes[:5])
            #     print ('Labels:')
            #     print (labels[:5])
            #     print ('0 = entailment, 1 = neutral, 2 = contradiction')
            #     first = False

            # Per-category statistics.
            for i in range(len(out_classes)):
                # NOTE: mapping back to test-file line indices assumes a fixed
                # evaluation batch size of 32.
                line_i = batch_i * 32 + i

                if labels[i] == 0:
                    n_total_entailment += 1
                elif labels[i] == 1:
                    n_total_neutral += 1
                elif labels[i] == 2:
                    n_total_contradiction += 1

                if line_i in set_idx_high_overlap:
                    n_total_high_overlap += 1
                if line_i in set_idx_reg_overlap:
                    n_total_reg_overlap += 1
                if line_i in set_idx_low_overlap:
                    n_total_low_overlap += 1
                if line_i in set_idx_long_sentence:
                    n_total_long_sentence += 1
                if line_i in set_idx_reg_sentence:
                    n_total_reg_sentence += 1
                if line_i in set_idx_short_sentence:
                    n_total_short_sentence += 1
                if line_i in set_idx_negation:
                    n_total_negation += 1
                if line_i in set_idx_quantifier:
                    n_total_quantifier += 1
                if line_i in set_idx_belief:
                    n_total_belief += 1

                if out_classes[i] == labels[i]:
                    if labels[i] == 0:
                        n_correct_entailment += 1
                    elif labels[i] == 1:
                        n_correct_neutral += 1
                    elif labels[i] == 2:
                        n_correct_contradiction += 1

                    if line_i in set_idx_high_overlap:
                        n_high_overlap += 1
                    if line_i in set_idx_reg_overlap:
                        n_reg_overlap += 1
                    if line_i in set_idx_low_overlap:
                        n_low_overlap += 1
                    if line_i in set_idx_long_sentence:
                        n_long_sentence += 1
                    if line_i in set_idx_reg_sentence:
                        n_reg_sentence += 1
                    if line_i in set_idx_short_sentence:
                        n_short_sentence += 1
                    if line_i in set_idx_negation:
                        n_negation += 1
                    if line_i in set_idx_quantifier:
                        n_quantifier += 1
                    if line_i in set_idx_belief:
                        n_belief += 1

            batch_i += 1

    # print ('Total Entailment:' + str(n_total_entailment))
    # print ('Correct Entailment:' + str(n_correct_entailment))
    # print ('Accuracy:' + str(float(n_correct_entailment) / n_total_entailment))
    # print ('Total Neutral:' + str(n_total_neutral))
    # print ('Correct Neutral:' + str(n_correct_neutral))
    # print ('Accuracy:' + str(float(n_correct_neutral) / n_total_neutral))
    #
    # print ('Total Contradiction:' + str(n_total_contradiction))
    # print ('Correct Contradiction:' + str(n_correct_contradiction))
    # print ('Accuracy:' + str(float(n_correct_contradiction) / n_total_contradiction))
    # print ('Total high overlap sentence:' + str(n_total_high_overlap))
    # print ('Correct high overlap sentence:' + str(n_high_overlap))
    # print ('Accuracy:' + str(float(n_high_overlap) / n_total_high_overlap))
    # print ('Total regular overlap sentence:' + str(n_total_reg_overlap))
    # print ('Correct regular overlap sentence:' + str(n_reg_overlap))
    # print ('Accuracy:' + str(float(n_reg_overlap) / n_total_reg_overlap))
    # print ('Total low overlap sentence:' + str(n_total_low_overlap))
    # print ('Correct low overlap sentence:' + str(n_low_overlap))
    # print ('Accuracy:' + str(float(n_low_overlap) / n_total_low_overlap))
    # print ('Total long sentence:' + str(n_total_long_sentence))
    # print ('Correct long sentence:' + str(n_long_sentence))
    # print ('Accuracy:' + str(float(n_long_sentence) / n_total_long_sentence))
    # print ('Total regular sentence:' + str(n_total_reg_sentence))
    # print ('Correct regular sentence:' + str(n_reg_sentence))
    # print ('Accuracy:' + str(float(n_reg_sentence) / n_total_reg_sentence))
    # print ('Total short sentence:' + str(n_total_short_sentence))
    # print ('Correct short sentence:' + str(n_short_sentence))
    # print ('Accuracy:' + str(float(n_short_sentence) / n_total_short_sentence))
    # print ('Total sentence with negation:' + str(n_total_negation))
    # print ('Correct sentence with negation:' + str(n_negation))
    # print ('Accuracy:' + str(float(n_negation) / n_total_negation))
    # print ('Total sentence with quantifier:' + str(n_total_quantifier))
    # print ('Correct sentence with quantifier:' + str(n_quantifier))
    # print ('Accuracy:' + str(float(n_quantifier) / n_total_quantifier))
    # print ('Total sentence with belief:' + str(n_total_belief))
    # print ('Correct sentence with belief:' + str(n_belief))
    # print ('Accuracy:' + str(float(n_belief) / n_total_belief))

    batch_time /= len(dataloader)
    total_time = time.time() - time_start
    accuracy /= (len(dataloader.dataset))

    return batch_time, total_time, accuracy
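# NOTE: the per-category breakdown in `test` assumes the `args.test_statistics`
# pickle maps category names to collections of 0-based example indices in
# test-set order. A hypothetical illustration of that layout (the indices
# below are made up):
example_test_statistics = {
    "high_overlap": [3, 17, 42],
    "reg_overlap": [0, 1, 5],
    "low_overlap": [2, 9],
    "long_sentence": [7, 11],
    "reg_sentence": [0, 3],
    "short_sentence": [1, 4],
    "negation": [6],
    "quantifier": [8],
    "belief": [10],
}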
def main(args, logger):
    stat_file = args.test_statistics
    device = args.local_rank if args.local_rank != -1 else (torch.device(
        'cuda:0') if torch.cuda.is_available() else torch.device('cpu'))
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
    load_path = args.load_path
    test_file = args.test_data
    embedding_file = args.embeddings
    batch_size = args.batch_size

    info = "\t* Loading testing data..."
    rank_logger_info(logger, args.local_rank, info)
    with open(test_file, "rb") as pkl:
        test_data = NLIDataset(pickle.load(pkl))
    test_loader = DataLoader(test_data, shuffle=False, batch_size=batch_size)
    with open(embedding_file, "rb") as pkl:
        embeddings = torch.tensor(pickle.load(pkl),
                                  dtype=torch.float).to(device)
    with open(stat_file, "rb") as pkl:
        test_statistics = pickle.load(pkl)

    info = "\t* Loading pretrained models..."
    rank_logger_info(logger, args.local_rank, info)
    model_path_lst = os.listdir(load_path)
    checkpoint_lst = [
        torch.load(os.path.join(load_path, pretrained_file),
                   map_location=torch.device(device))
        for pretrained_file in model_path_lst
    ]
    model_state_dict_lst = [
        checkpoint["model_state_dict"] for checkpoint in checkpoint_lst
    ]
    best_score_lst = [
        checkpoint["best_score"] for checkpoint in checkpoint_lst
    ]
    epochs_count = [
        checkpoint["epochs_count"] for checkpoint in checkpoint_lst
    ]
    model_n = len(model_state_dict_lst)
    info = "\t* Loading done : {}".format(model_n)

    rank_logger_info(logger, args.local_rank, info)
    model_lst = [
        DRLSTM(embeddings.shape[0],
               embeddings.shape[1],
               hidden_size=args.hidden_size,
               embeddings=embeddings,
               padding_idx=0,
               dropout=args.dropout,
               num_classes=args.num_classes,
               device=device,
               pooling_method_lst=args.pooling_method,
               embedding_dropout=args.embedding_dropout)
        for i in range(model_n)
    ]
    for idx, model in enumerate(model_lst):
        model.load_state_dict(model_state_dict_lst[idx])
        model.to(device)
        for params in model.parameters():
            params.requires_grad = False

    if args.ensemble_mode == 1 or args.ensemble_mode == 2:
        info = "\t* training..."
        rank_logger_info(logger, args.local_rank, info)
        if args.ensemble_mode == 1:
            ensemble_model = Ensemble_model1(model_n)
        else:
            ensemble_model = Ensemble_model2(model_n)
        ensemble_model.to(device)
        with open(args.valid_data, "rb") as pkl:
            valid_data = NLIDataset(pickle.load(pkl))
        valid_loader = DataLoader(valid_data,
                                  shuffle=False,
                                  batch_size=args.batch_size)
        criterion = nn.CrossEntropyLoss()
        if args.optim == "adam":
            optimizer = torch.optim.Adam(ensemble_model.parameters(),
                                         lr=args.lr)
        elif args.optim == "rmsprop":
            optimizer = torch.optim.RMSprop(ensemble_model.parameters(),
                                            lr=args.lr)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode="max",
                                                               factor=0.5,
                                                               patience=0)
        best_score = 0.0
        start_epoch = 1

        # Data for loss curves plot.
        epochs_count = []
        train_losses = []
        valid_losses = []
        train_accuracy = []
        valid_accuracy = []
        info = "\n" + 20 * "=" + "Training model on device: {}".format(
            device) + 20 * "="
        rank_logger_info(logger, args.local_rank, info)
        patience_counter = 0
        for epoch in range(start_epoch, args.epochs + 1):
            epochs_count.append(epoch)
            info = "* Training epoch {}:".format(epoch)
            rank_logger_info(logger, args.local_rank, info)
            epoch_time, epoch_loss, epoch_accuracy = multi_model_train(
                args, epoch, ensemble_model, model_lst, valid_loader,
                optimizer, criterion, args.max_gradient_norm, device)
            train_losses.append(epoch_loss)
            train_accuracy.append(epoch_accuracy)
            info = "Training epoch: {}, time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n".format(
                epoch, epoch_time, epoch_loss, (epoch_accuracy * 100))
            rank_logger_info(logger, args.local_rank, info)
            weight_lst = (ensemble_model.weight_layer.weight.data
                          .cpu().numpy().tolist()[0])
            rank_logger_info(logger, args.local_rank, weight_lst)
            info = "* Validation for epoch {}:".format(epoch)
            rank_logger_info(logger, args.local_rank, info)
            epoch_time, epoch_loss, epoch_accuracy = multi_model_valid(
                ensemble_model, model_lst, test_loader, criterion, device)
            valid_losses.append(epoch_loss)
            valid_accuracy.append(epoch_accuracy)
            info = "Validing epoch: {}, time: {:.4f}s, loss: {:.4f}, accuracy: {:.4f}%\n".format(
                epoch, epoch_time, epoch_loss, (epoch_accuracy * 100))
            rank_logger_info(logger, args.local_rank, info)
            scheduler.step(epoch_accuracy)
            if epoch_accuracy <= best_score:
                patience_counter += 1
            else:
                best_score = epoch_accuracy
                best_model = ensemble_model
                patience_counter = 0
                if args.local_rank in [-1, 0]:
                    torch.save(
                        {
                            "epoch": epoch,
                            "model_state_dict": best_model.state_dict(),
                            "best_score": best_score,
                            "epochs_count": epochs_count,
                            "train_losses": train_losses,
                            "valid_losses": valid_losses
                        }, os.path.join(args.save_path, "best.pth.tar"))
            if patience_counter >= args.patience:
                info = "-> Early stopping: patience limit reached, stopping..."
                rank_logger_info(logger, args.local_rank, info)
                break
        if args.local_rank in [-1, 0]:
            report_result(epochs_count, train_losses, valid_losses,
                          train_accuracy, valid_accuracy, args.save_path)
    else:  # modes 3 and 4
        info = "\t* testing..."
        rank_logger_info(logger, args.local_rank, info)
        if args.ensemble_mode == 3:
            ensemble_model = Ensemble_model3(model_n)
        else:
            ensemble_model = Ensemble_model4(model_n)
        ensemble_model.to(device)
        ensemble_model.eval()
        time_start = time.time()
        batch_time = 0.0
        accuracy = 0.0
        with torch.no_grad():
            for batch in test_loader:
                batch_start = time.time()
                premises = batch["premise"].to(device)
                premises_lengths = batch["premise_length"].to(device)
                hypotheses = batch["hypothesis"].to(device)
                hypotheses_lengths = batch["hypothesis_length"].to(device)
                labels = batch["label"].to(device)
                logits_probs_lst = [
                    model(premises, premises_lengths, hypotheses,
                          hypotheses_lengths) for model in model_lst
                ]
                logits_lst = [i[0].unsqueeze(1) for i in logits_probs_lst]
                probs_lst = [i[1].unsqueeze(1) for i in logits_probs_lst]
                _, probs = ensemble_model(logits_lst, probs_lst)
                accuracy += correct_predictions(probs, labels)
                batch_time += time.time() - batch_start
                _, out_classes = probs.max(dim=1)
            batch_time /= len(test_loader)
            total_time = time.time() - time_start
            accuracy /= (len(test_loader.dataset))
            info = "-> Average batch processing time: {:.4f}s, total test time:\
             {:.4f}s, accuracy: {:.4f}%".format(batch_time, total_time,
                                                (accuracy * 100))
            rank_logger_info(logger, args.local_rank, info)
    def fit(self):
        self._start_wandb()
        dataset = self.dataset.shuffle(3000) \
            .batch(self.args.batch_size) \
            .prefetch(4)

        val_dataset = self.val_dataset \
            .batch(self.args.batch_size) \
            .prefetch(4)

        for epoch in range(self.args.epochs):
            self._epoch.assign(epoch)
            log_append = dict()

            # Reset metrics
            for m in self.metrics.values():
                m.reset_states()

            # Train on train dataset
            for epoch_step, x in enumerate(dataset):
                self._epoch_step.assign(epoch_step)
                loss, regression_loss, class_loss = self.train_on_batch(x)
                self.metrics['loss'].update_state(loss)
                self.metrics['regression_loss'].update_state(regression_loss)
                self.metrics['class_loss'].update_state(class_loss)

            # Run validation
            for x in val_dataset:
                loss, regression_loss, class_loss = self.evaluate_on_batch(x)
                self.metrics['val_loss'].update_state(loss)
                self.metrics['val_regression_loss'].update_state(
                    regression_loss)
                self.metrics['val_class_loss'].update_state(class_loss)

            # Compute Straka's metric
            # TODO: vectorize Straka's metric
            predictions = self.predict(self.val_dataset)
            for (boxes, classes, scores), gold in zip(predictions,
                                                      self.val_dataset):
                gold_classes = gold['gt-class'].numpy()
                gold_boxes = gold['gt-bbox'].numpy()
                num_gt = gold['gt-length'].numpy()
                gold_classes = gold_classes[:num_gt]
                gold_boxes = gold_boxes[:num_gt]
                keep = scores > self.args.score_threshold
                boxes, classes = boxes[keep], classes[keep]
                self.metrics['val_score'].update_state(
                    utils.correct_predictions(gold_classes, gold_boxes,
                                              classes, boxes))

            # The mAP metric should be implemented here. Note that the generated
            # predictions use transformed bounding boxes, i.e. boxes scaled to
            # args.image_size, while the original dataset has varying image
            # sizes; the metric needs to account for this.
            predictions = self.predict(self.val_dataset)
            self.coco_metric.evaluate(predictions)

            # Save model every 20 epochs
            if (epoch + 1) % 20 == 0:
                self.save()
                val_acc = output_predictions(self, 'dev')
                output_predictions(self, 'test')
                if hasattr(self, '_wandb'):
                    self._wandb.save('svhn_classification_dev.txt')
                    self._wandb.save('svhn_classification_test.txt')
                print('model saved')
                print(f'validation score: {val_acc * 100:.2f}')
                log_append = dict(saved_val_score=val_acc, **log_append)

            # Log current values
            self.log(**log_append)