Code Example #1
    def __init__(
        self,
        model: nn.Module,
        model_filepath: str,
        datasets_manager: DatasetsManager,
        tokens_namespace: str = "tokens",
        normalized_probs_namespace: str = "normalized_probs",
    ):

        super(ClassificationInference, self).__init__(
            model=model,
            model_filepath=model_filepath,
            datasets_manager=datasets_manager,
        )
        self.batch_size = 32
        self.tokens_namespace = tokens_namespace
        self.normalized_probs_namespace = normalized_probs_namespace
        self.label_namespace = self.datasets_manager.label_namespaces[0]

        self.labelname2idx_mapping = self.datasets_manager.get_label_idx_mapping(
            label_namespace=self.label_namespace)
        self.idx2labelname_mapping = self.datasets_manager.get_idx_label_mapping(
            label_namespace=self.label_namespace)

        self.load_model()

        self.metrics_calculator = PrecisionRecallFMeasure(
            datasets_manager=datasets_manager)
        self.output_analytics = None

        # create a dataframe with all the information
        self.output_df = None
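
The constructor wires up the label mappings, loads the trained weights via load_model(), and prepares the metric calculator, so the object is ready to run inference immediately. A minimal usage sketch, not from the source (the checkpoint path is hypothetical), using the entry points shown in Code Example #12:

infer = ClassificationInference(
    model=classifier,  # an nn.Module compatible with the saved checkpoint
    model_filepath="experiments/experiment_1/best_model.pt",  # hypothetical path
    datasets_manager=datasets_manager,
)
infer.run_test()        # runs inference on the test split and builds output_df
infer.report_metrics()  # prints precision/recall/F-measure tables per namespace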
Code Example #2
    def test_classifier_produces_correct_precision(self,
                                                   setup_simple_classifier):
        iter_dict, simple_classifier, batch_size, num_classes = setup_simple_classifier
        output = simple_classifier(iter_dict,
                                   is_training=True,
                                   is_validation=False,
                                   is_test=False)
        idx2labelname_mapping = {
            0: "good class",
            1: "bad class",
            2: "average_class"
        }
        metrics_calc = PrecisionRecallFMeasure(
            idx2labelname_mapping=idx2labelname_mapping)

        metrics_calc.calc_metric(iter_dict=iter_dict,
                                 model_forward_dict=output)
        metrics = metrics_calc.get_metric()
        precision = metrics["precision"]

        # NOTE: topk returns the last value in the dimension in case
        # all the values are equal.
        expected_precision = {1: 0, 2: 0}

        assert len(precision) == 2

        for class_label, precision_value in precision.items():
            assert precision_value == expected_precision[class_label]
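
The NOTE in this test leans on torch.topk's tie-breaking, which PyTorch does not document as guaranteed behaviour. A quick sketch to verify it on the installed version before trusting expected_precision above:

import torch

tied = torch.FloatTensor([[0.5, 0.5, 0.5]])
values, indices = torch.topk(tied, k=1, dim=1)
# the index returned for an all-equal row decides which class counts as "predicted"
print(values, indices)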
Code Example #3
def setup_sectlabel_bow_glove_infer(request, clf_datasets_manager,
                                    tmpdir_factory):
    track_for_best = request.param
    sample_proportion = 0.5
    datasets_manager = clf_datasets_manager
    word_embedder = WordEmbedder(embedding_type="glove_6B_50")
    bow_encoder = BOW_Encoder(embedder=word_embedder)
    classifier = SimpleClassifier(
        encoder=bow_encoder,
        encoding_dim=word_embedder.get_embedding_dimension(),
        num_classes=2,
        classification_layer_bias=True,
        datasets_manager=datasets_manager,
    )
    train_metric = PrecisionRecallFMeasure(datasets_manager=datasets_manager)
    validation_metric = PrecisionRecallFMeasure(
        datasets_manager=datasets_manager)
    test_metric = PrecisionRecallFMeasure(datasets_manager=datasets_manager)

    optimizer = torch.optim.Adam(params=classifier.parameters())
    batch_size = 1
    save_dir = tmpdir_factory.mktemp("experiment_1")
    num_epochs = 1
    save_every = 1
    log_train_metrics_every = 10

    engine = Engine(
        model=classifier,
        datasets_manager=datasets_manager,
        optimizer=optimizer,
        batch_size=batch_size,
        save_dir=save_dir,
        num_epochs=num_epochs,
        save_every=save_every,
        log_train_metrics_every=log_train_metrics_every,
        train_metric=train_metric,
        validation_metric=validation_metric,
        test_metric=test_metric,
        track_for_best=track_for_best,
        sample_proportion=sample_proportion,
    )

    engine.run()
    model_filepath = pathlib.Path(save_dir).joinpath("best_model.pt")
    infer = ClassificationInference(
        model=classifier,
        model_filepath=str(model_filepath),
        datasets_manager=datasets_manager,
    )
    return infer
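
Because the fixture reads request.param, a consuming test has to parametrize it indirectly. A hypothetical sketch (the track_for_best values are assumptions, not taken from the source):

import pytest

@pytest.mark.parametrize(
    "setup_sectlabel_bow_glove_infer", ["loss", "macro_fscore"], indirect=True
)
def test_inference_runs(setup_sectlabel_bow_glove_infer):
    infer = setup_sectlabel_bow_glove_infer
    infer.run_test()  # should complete without raising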
Code Example #4
def setup_data_one_true_class_missing():
    """
    A batch of instances seen during training might not contain
    every true class. This fixture sets up that situation so the
    test can check how the metric behaves in that case.
    :return:
    """
    predicted_probs = torch.FloatTensor([[0.8, 0.1, 0.2], [0.2, 0.5, 0.3]])
    idx2labelname_mapping = {
        0: "good class",
        1: "bad class",
        2: "average_class"
    }
    labels = torch.LongTensor([0, 2]).view(-1, 1)

    expected_precision = {0: 1.0, 1: 0.0, 2: 0.0}
    expected_recall = {0: 1.0, 1: 0.0, 2: 0.0}
    expected_fscore = {0: 1.0, 1: 0.0, 2: 0.0}

    accuracy = PrecisionRecallFMeasure(
        idx2labelname_mapping=idx2labelname_mapping)

    return (
        predicted_probs,
        labels,
        accuracy,
        {
            "expected_precision": expected_precision,
            "expected_recall": expected_recall,
            "expected_fscore": expected_fscore,
        },
    )
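
A sketch of a test consuming this fixture. The iter_dict key "label" is an assumption; "normalized_probs" follows the model output namespace used in Code Example #12:

def test_metric_when_true_class_missing(setup_data_one_true_class_missing):
    predicted_probs, labels, metric, expected = setup_data_one_true_class_missing
    metric.calc_metric(
        iter_dict={"label": labels},  # assumed key name
        model_forward_dict={"normalized_probs": predicted_probs},
    )
    precision = metric.get_metric()["precision"]
    assert precision == expected["expected_precision"]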
Code Example #5
    def __init__(
        self, model: nn.Module, model_filepath: str, dataset: BaseTextClassification
    ):

        super(ClassificationInference, self).__init__(
            model=model, model_filepath=model_filepath, dataset=dataset
        )
        self.batch_size = 32

        self.labelname2idx_mapping = self.dataset.get_classname2idx()
        self.idx2labelname_mapping = {
            idx: label_name for label_name, idx in self.labelname2idx_mapping.items()
        }
        self.load_model()
        self.metrics_calculator = PrecisionRecallFMeasure(
            idx2labelname_mapping=self.idx2labelname_mapping
        )
        self.output_analytics = None

        # create a dataframe with all the information
        self.output_df = None
Code Example #6
def setup_data_to_test_length():
    predicted_probs = torch.FloatTensor([[0.1, 0.8, 0.2], [0.2, 0.3, 0.5]])
    labels = torch.LongTensor([0, 2]).view(-1, 1)
    idx2labelname_mapping = {
        0: "good class",
        1: "bad class",
        2: "average_class"
    }

    accuracy = PrecisionRecallFMeasure(
        idx2labelname_mapping=idx2labelname_mapping)

    expected_length = 3

    return predicted_probs, labels, accuracy, expected_length
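
Here labels contains only classes 0 and 2, yet expected_length is 3: the metric should report one entry per known class, including class 1, which never occurs in the batch. A consuming test follows the same calc_metric pattern sketched after Code Example #4 and asserts that len(metrics["precision"]) equals expected_length.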
Code Example #7
def setup_data_for_all_zeros(clf_dataset_manager):
    predicted_probs = torch.FloatTensor([[0.9, 0.1], [0.3, 0.7]])
    datasets_manager = clf_dataset_manager
    labels = torch.LongTensor([1, 0]).view(-1, 1)

    expected_precision = {0: 0.0, 1: 0.0}
    expected_recall = {0: 0.0, 1: 0.0}
    expected_fmeasure = {0: 0.0, 1: 0.0}
    expected_macro_precision = 0.0
    expected_macro_recall = 0.0
    expected_macro_fscore = 0.0
    expected_num_tps = {0: 0.0, 1: 0.0}
    expected_num_fps = {0: 1.0, 1: 1.0}
    expected_num_fns = {0: 1.0, 1: 1.0}
    expected_micro_precision = 0.0
    expected_micro_recall = 0.0
    expected_micro_fscore = 0.0

    prf_metric = PrecisionRecallFMeasure(datasets_manager=datasets_manager)
    return (
        predicted_probs,
        labels,
        prf_metric,
        datasets_manager,
        {
            "expected_precision": expected_precision,
            "expected_recall": expected_recall,
            "expected_fscore": expected_fmeasure,
            "expected_macro_precision": expected_macro_precision,
            "expected_macro_recall": expected_macro_recall,
            "expected_macro_fscore": expected_macro_fscore,
            "expected_num_tps": expected_num_tps,
            "expected_num_fps": expected_num_fps,
            "expected_num_fns": expected_num_fns,
            "expected_micro_precision": expected_micro_precision,
            "expected_micro_recall": expected_micro_recall,
            "expected_micro_fscore": expected_micro_fscore,
        },
    )
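
Every prediction in this fixture is wrong: the row-wise argmax of predicted_probs is [0, 1] against true labels [1, 0], so each class collects one false positive and one false negative and no true positives. The micro averages then follow from the standard definitions, as this worked check shows:

# micro-precision = sum(TP) / (sum(TP) + sum(FP)); analogously for recall
tps, fps, fns = {0: 0.0, 1: 0.0}, {0: 1.0, 1: 1.0}, {0: 1.0, 1: 1.0}
micro_precision = sum(tps.values()) / (sum(tps.values()) + sum(fps.values()))  # 0 / 2 = 0.0
micro_recall = sum(tps.values()) / (sum(tps.values()) + sum(fns.values()))     # 0 / 2 = 0.0
print(micro_precision, micro_recall)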
Code Example #8
def setup_data_basecase(clf_dataset_manager):
    dataset_manager = clf_dataset_manager
    prf_metric = PrecisionRecallFMeasure(dataset_manager)
    predicted_probs = torch.FloatTensor([[0.1, 0.9], [0.7, 0.3]])
    labels = torch.LongTensor([1, 0]).view(-1, 1)

    expected_precision = {0: 1.0, 1: 1.0}
    expected_recall = {0: 1.0, 1: 1.0}
    expected_fmeasure = {0: 1.0, 1: 1.0}
    expected_macro_precision = 1.0
    expected_macro_recall = 1.0
    expected_macro_fscore = 1.0
    expected_num_tps = {0: 1.0, 1: 1.0}
    expected_num_fps = {0: 0.0, 1: 0.0}
    expected_num_fns = {0: 0.0, 1: 0.0}
    expected_micro_precision = 1.0
    expected_micro_recall = 1.0
    expected_micro_fscore = 1.0

    return (
        predicted_probs,
        labels,
        prf_metric,
        dataset_manager,
        {
            "expected_precision": expected_precision,
            "expected_recall": expected_recall,
            "expected_fscore": expected_fmeasure,
            "expected_macro_precision": expected_macro_precision,
            "expected_macro_recall": expected_macro_recall,
            "expected_macro_fscore": expected_macro_fscore,
            "expected_num_tps": expected_num_tps,
            "expected_num_fps": expected_num_fps,
            "expected_num_fns": expected_num_fns,
            "expected_micro_precision": expected_micro_precision,
            "expected_micro_recall": expected_micro_recall,
            "expected_micro_fscore": expected_micro_fscore,
        },
    )
Code Example #9
def setup_data_basecase():
    predicted_probs = torch.FloatTensor([[0.1, 0.9], [0.7, 0.3]])
    labels = torch.LongTensor([1, 0]).view(-1, 1)
    idx2labelname_mapping = {0: "good class", 1: "bad class"}

    expected_precision = {0: 1.0, 1: 1.0}
    expected_recall = {0: 1.0, 1: 1.0}
    expected_fmeasure = {0: 1.0, 1: 1.0}
    expected_macro_precision = 1.0
    expected_macro_recall = 1.0
    expected_macro_fscore = 1.0
    expected_num_tps = {0: 1.0, 1: 1.0}
    expected_num_fps = {0: 0.0, 1: 0.0}
    expected_num_fns = {0: 0.0, 1: 0.0}
    expected_micro_precision = 1.0
    expected_micro_recall = 1.0
    expected_micro_fscore = 1.0

    accuracy = PrecisionRecallFMeasure(
        idx2labelname_mapping=idx2labelname_mapping)
    return (
        predicted_probs,
        labels,
        accuracy,
        {
            "expected_precision": expected_precision,
            "expected_recall": expected_recall,
            "expected_fscore": expected_fmeasure,
            "expected_macro_precision": expected_macro_precision,
            "expected_macro_recall": expected_macro_recall,
            "expected_macro_fscore": expected_macro_fscore,
            "expected_num_tps": expected_num_tps,
            "expected_num_fps": expected_num_fps,
            "expected_num_fns": expected_num_fns,
            "expected_micro_precision": expected_micro_precision,
            "expected_micro_recall": expected_micro_recall,
            "expected_micro_fscore": expected_micro_fscore,
        },
    )
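
Code Examples #8 and #9 are the same base case against two versions of the PrecisionRecallFMeasure constructor: #8 builds the metric from a datasets_manager, while #9 passes an idx2labelname_mapping directly, the older style also seen in Code Examples #2 and #4. The data is identical in both: the row-wise argmax of predicted_probs matches the labels exactly, so every per-class and averaged metric comes out as 1.0.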
Code Example #10
        combine_strategy=COMBINE_STRATEGY,
        device=torch.device(DEVICE),
    )

    encoding_dim = (
        2 * HIDDEN_DIMENSION
        if BIDIRECTIONAL and COMBINE_STRATEGY == "concat"
        else HIDDEN_DIMENSION
    )

    model = SimpleClassifier(
        encoder=encoder,
        encoding_dim=encoding_dim,
        num_classes=NUM_CLASSES,
        classification_layer_bias=True,
    )

    optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
    metric = PrecisionRecallFMeasure(
        idx2labelname_mapping=train_dataset.idx2classname)

    engine = Engine(
        model=model,
        train_dataset=train_dataset,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        optimizer=optimizer,
        batch_size=BATCH_SIZE,
        save_dir=MODEL_SAVE_DIR,
        num_epochs=NUM_EPOCHS,
        save_every=SAVE_EVERY,
        log_train_metrics_every=LOG_TRAIN_METRICS_EVERY,
        tensorboard_logdir=TENSORBOARD_LOGDIR,
        device=torch.device(DEVICE),
        metric=metric,
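
The encoding_dim arithmetic is the detail worth noting in this snippet: with a bidirectional encoder and the "concat" combine strategy, forward and backward hidden states are concatenated, so the classifier receives vectors of size 2 * HIDDEN_DIMENSION; in every other configuration it receives HIDDEN_DIMENSION.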
Code Example #11
    encoder = BOW_Encoder(embedder=embedder,
                          aggregation_type=args.word_aggregation,
                          device=args.device)

    model = SimpleClassifier(
        encoder=encoder,
        encoding_dim=1024,
        num_classes=data_manager.num_labels["label"],
        classification_layer_bias=True,
        datasets_manager=data_manager,
        device=args.device,
    )

    optimizer = optim.Adam(params=model.parameters(), lr=args.lr)
    train_metric = PrecisionRecallFMeasure(datasets_manager=data_manager)
    dev_metric = PrecisionRecallFMeasure(datasets_manager=data_manager)
    test_metric = PrecisionRecallFMeasure(datasets_manager=data_manager)

    engine = Engine(
        model=model,
        datasets_manager=data_manager,
        optimizer=optimizer,
        batch_size=args.bs,
        save_dir=args.model_save_dir,
        num_epochs=args.epochs,
        save_every=args.save_every,
        log_train_metrics_every=args.log_train_metrics_every,
        device=args.device,
        train_metric=train_metric,
        validation_metric=dev_metric,
Code Example #12
class ClassificationInference(BaseClassificationInference):
    """
    The SciWING engine runs the test lines through the classifier
    and returns the predictions/probabilities for the different classes.
    At a later point in time this class should be able to take any
    collection of lines (for example, from a file) and produce the output.

    This class also helps in performing various interactions with
    the results on the test dataset.
    Some features are
    1) Show the confusion matrix
    2) Investigate a particular example in the test dataset
    3) Get instances whose true label is 1 but that were classified as 2, and the like

    All it needs is the configuration file stored under every experiment,
    with a vocab already stored in the experiment folder
    """
    def __init__(
        self,
        model: nn.Module,
        model_filepath: str,
        datasets_manager: DatasetsManager,
        tokens_namespace: str = "tokens",
        normalized_probs_namespace: str = "normalized_probs",
    ):

        super(ClassificationInference, self).__init__(
            model=model,
            model_filepath=model_filepath,
            datasets_manager=datasets_manager,
        )
        self.batch_size = 32
        self.tokens_namespace = tokens_namespace
        self.normalized_probs_namespace = normalized_probs_namespace
        self.label_namespace = self.datasets_manager.label_namespaces[0]

        self.labelname2idx_mapping = self.datasets_manager.get_label_idx_mapping(
            label_namespace=self.label_namespace)
        self.idx2labelname_mapping = self.datasets_manager.get_idx_label_mapping(
            label_namespace=self.label_namespace)

        self.load_model()

        self.metrics_calculator = PrecisionRecallFMeasure(
            datasets_manager=datasets_manager)
        self.output_analytics = None

        # create a dataframe with all the information
        self.output_df = None

    def run_inference(self) -> Dict[str, Any]:

        with self.msg_printer.loading(text="Running inference on test data"):
            loader = DataLoader(
                dataset=self.datasets_manager.test_dataset,
                batch_size=self.batch_size,
                shuffle=False,
                collate_fn=list,
            )
            output_analytics = {}

            # contains the predicted class names for all the instances
            pred_class_names = []
            true_class_names = []  # contains the true class names for all the instances
            sentences = []  # batch sentences in english
            true_labels_indices = []
            predicted_labels_indices = []
            all_pred_probs = []
            self.metrics_calculator.reset()

            for lines_labels in loader:
                lines_labels = list(zip(*lines_labels))
                lines = lines_labels[0]
                labels = lines_labels[1]

                batch_sentences = [line.text for line in lines]
                model_output_dict = self.model_forward_on_lines(lines=lines)
                normalized_probs = model_output_dict[
                    self.normalized_probs_namespace]
                self.metrics_calculator.calc_metric(
                    lines=lines,
                    labels=labels,
                    model_forward_dict=model_output_dict)
                true_label_ind, true_label_names = self.get_true_label_indices_names(
                    labels=labels)
                (
                    pred_label_indices,
                    pred_label_names,
                ) = self.model_output_dict_to_prediction_indices_names(
                    model_output_dict=model_output_dict)

                true_label_ind = torch.LongTensor(true_label_ind)
                true_labels_indices.append(true_label_ind)
                true_class_names.extend(true_label_names)
                predicted_labels_indices.extend(pred_label_indices)
                pred_class_names.extend(pred_label_names)
                sentences.extend(batch_sentences)
                all_pred_probs.append(normalized_probs)

            # contains predicted probs for all the instances
            all_pred_probs = torch.cat(all_pred_probs, dim=0)
            true_labels_indices = torch.cat(true_labels_indices,
                                            dim=0).squeeze()

            # torch.LongTensor of shape (N,) after the squeeze
            output_analytics["true_labels_indices"] = true_labels_indices
            output_analytics["predicted_labels_indices"] = predicted_labels_indices
            output_analytics["pred_class_names"] = pred_class_names
            output_analytics["true_class_names"] = true_class_names
            output_analytics["sentences"] = sentences
            output_analytics["all_pred_probs"] = all_pred_probs

        self.msg_printer.good(title="Finished running inference")
        return output_analytics

    def model_forward_on_lines(self, lines: List[Line]):
        with torch.no_grad():
            model_output_dict = self.model(lines=lines,
                                           is_training=False,
                                           is_validation=False,
                                           is_test=True)
        return model_output_dict

    def get_misclassified_sentences(self, true_label_idx: int,
                                    pred_label_idx: int):
        """This returns the true label misclassified as
        pred label idx

        Parameters
        ----------
        true_label_idx : int
            The label index of the true class name
        pred_label_idx : int
            The label index of the predicted class name


        Returns
        -------
        List[str]
            A list of strings where the true class is classified as pred class.

        """

        instances_idx = self.output_df[
            self.output_df["true_labels_indices"].isin([true_label_idx])
            & self.output_df["predicted_labels_indices"].isin([pred_label_idx])
        ].index.tolist()

        for idx in instances_idx:
            sentence = self.output_analytics["sentences"][idx]

            if true_label_idx != pred_label_idx:
                stylized_sentence = self.msg_printer.text(
                    title=sentence,
                    icon=MESSAGES.FAIL,
                    color=MESSAGES.FAIL,
                    no_print=True,
                )
            else:
                stylized_sentence = self.msg_printer.text(
                    title=sentence,
                    icon=MESSAGES.GOOD,
                    color=MESSAGES.GOOD,
                    no_print=True,
                )

            print(stylized_sentence)

    def print_confusion_matrix(self) -> None:
        """ Prints the confusion matrix for the test dataset
        """
        self.metrics_calculator.print_confusion_metrics(
            predicted_probs=self.output_analytics["all_pred_probs"],
            labels=self.output_analytics["true_labels_indices"].unsqueeze(1),
        )

    def report_metrics(self):
        metrics = self.metrics_calculator.report_metrics()
        for namespace, table in metrics.items():
            self.msg_printer.divider(f"Results for {namespace.upper()}")
            print(table)

    @deprecated(
        reason="This method is deprecated. It will be removed in version 0.1")
    def generate_report_for_paper(self):
        """ Generates just the fscore to be used in reporting on print

        """
        paper_report = self.metrics_calculator.report_metrics(
            report_type="paper")
        class_numbers = sorted(self.idx2labelname_mapping.keys(),
                               reverse=False)
        row_names = [
            f"class_{class_num} - ({self.idx2labelname_mapping[class_num]})"
            for class_num in class_numbers
        ]
        row_names.extend(["Micro-Fscore", "Macro-Fscore"])
        return paper_report, row_names

    def model_output_dict_to_prediction_indices_names(
            self, model_output_dict: Dict[str, Any]) -> Tuple[List[int], List[str]]:
        normalized_probs = model_output_dict["normalized_probs"]
        pred_probs, pred_indices = torch.topk(normalized_probs, k=1, dim=1)
        pred_indices = pred_indices.squeeze(1).tolist()
        pred_classnames = [
            self.idx2labelname_mapping[pred_index]
            for pred_index in pred_indices
        ]
        return pred_indices, pred_classnames

    def infer_batch(self, lines: List[str]) -> List[str]:
        """ Runs inference on a batch of lines
        This method can be used for applications. When APIS are being developed
        to serve over the web or when terminal applications are being written
        to read from files and infer, this method comes in handy


        Parameters
        ----------
        lines : List[str]
            List of text spans to be infered

        Returns
        -------
        List[str]
            Reutrns the class names for all the sentences in the input

        """
        lines = [self.datasets_manager.make_line(line=line) for line in lines]

        model_output_dict = self.model_forward_on_lines(lines=lines)
        _, pred_classnames = self.model_output_dict_to_prediction_indices_names(
            model_output_dict=model_output_dict)
        return pred_classnames

    def on_user_input(self, line: str) -> str:
        """ Runs the inference when the user inputs a single sentence either on the terminal
        or some other application

        Parameters
        ----------
        line : str
            The line entered by the user

        Returns
        -------
        str
            The class label that is inferred for the user input

        """
        return self.infer_batch(lines=[line])[0]

    def get_true_label_indices_names(
            self, labels: List[Label]) -> Tuple[List[int], List[str]]:
        label_names = [label.text for label in labels]
        label_indices = [
            self.labelname2idx_mapping[label_name]
            for label_name in label_names
        ]
        return label_indices, label_names

    def run_test(self):
        """ Runs inference and reports test metrics
        """
        self.output_analytics = self.run_inference()
        self.output_df = pd.DataFrame(self.output_analytics)
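
After run_test() has populated output_analytics and output_df, the class supports the error-analysis workflow its docstring promises. A short sketch, not from the source, assuming infer is a ClassificationInference built as in Code Example #3 (the label indices are hypothetical):

infer.run_test()
infer.print_confusion_matrix()  # confusion matrix over the test split
# sentences whose true label is 1 but which were predicted as 2
infer.get_misclassified_sentences(true_label_idx=1, pred_label_idx=2)
print(infer.on_user_input("A single sentence typed by the user"))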
Code Example #13
        hidden_dim=HIDDEN_DIMENSION,
        combine_strategy=COMBINE_STRATEGY,
        bidirectional=BIDIRECTIONAL,
        device=torch.device(DEVICE),
    )

    classifier_encoding_dim = 2 * HIDDEN_DIMENSION if BIDIRECTIONAL else HIDDEN_DIMENSION
    model = SimpleClassifier(
        encoder=encoder,
        encoding_dim=classifier_encoding_dim,
        num_classes=NUM_CLASSES,
        classification_layer_bias=True,
    )

    optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
    metric = PrecisionRecallFMeasure(train_dataset.idx2classname)

    engine = Engine(
        model=model,
        train_dataset=train_dataset,
        validation_dataset=validation_dataset,
        test_dataset=test_dataset,
        optimizer=optimizer,
        batch_size=BATCH_SIZE,
        save_dir=MODEL_SAVE_DIR,
        num_epochs=NUM_EPOCHS,
        save_every=SAVE_EVERY,
        log_train_metrics_every=LOG_TRAIN_METRICS_EVERY,
        device=torch.device(DEVICE),
        metric=metric,
        use_wandb=True,
Code Example #14
def setup_engine_test_with_simple_classifier(request, tmpdir_factory):
    MAX_NUM_WORDS = 1000
    MAX_LENGTH = 50
    vocab_store_location = tmpdir_factory.mktemp("tempdir").join("vocab.json")
    DEBUG = True
    BATCH_SIZE = 1
    NUM_TOKENS = 3
    EMB_DIM = 300

    train_dataset = SectLabelDataset(
        filename=SECT_LABEL_FILE,
        dataset_type="train",
        max_num_words=MAX_NUM_WORDS,
        max_instance_length=MAX_LENGTH,
        word_vocab_store_location=vocab_store_location,
        debug=DEBUG,
        word_embedding_type="random",
        word_embedding_dimension=EMB_DIM,
    )

    validation_dataset = SectLabelDataset(
        filename=SECT_LABEL_FILE,
        dataset_type="valid",
        max_num_words=MAX_NUM_WORDS,
        max_instance_length=MAX_LENGTH,
        word_vocab_store_location=vocab_store_location,
        debug=DEBUG,
        word_embedding_type="random",
        word_embedding_dimension=EMB_DIM,
    )

    test_dataset = SectLabelDataset(
        filename=SECT_LABEL_FILE,
        dataset_type="test",
        max_num_words=MAX_NUM_WORDS,
        max_instance_length=MAX_LENGTH,
        word_vocab_store_location=vocab_store_location,
        debug=DEBUG,
        word_embedding_type="random",
        word_embedding_dimension=EMB_DIM,
    )

    VOCAB_SIZE = MAX_NUM_WORDS + len(train_dataset.word_vocab.special_vocab)
    NUM_CLASSES = train_dataset.get_num_classes()
    NUM_EPOCHS = 1
    embedding = Embedding.from_pretrained(torch.zeros([VOCAB_SIZE, EMB_DIM]))
    labels = torch.LongTensor([1])
    metric = PrecisionRecallFMeasure(
        idx2labelname_mapping=train_dataset.idx2classname)
    embedder = VanillaEmbedder(embedding_dim=EMB_DIM, embedding=embedding)
    encoder = BOW_Encoder(emb_dim=EMB_DIM,
                          embedder=embedder,
                          dropout_value=0,
                          aggregation_type="sum")
    tokens = np.random.randint(0,
                               VOCAB_SIZE - 1,
                               size=(BATCH_SIZE, NUM_TOKENS))
    tokens = torch.LongTensor(tokens)
    model = SimpleClassifier(
        encoder=encoder,
        encoding_dim=EMB_DIM,
        num_classes=NUM_CLASSES,
        classification_layer_bias=False,
    )

    optimizer = optim.SGD(model.parameters(), lr=0.01)
    engine = Engine(
        model,
        train_dataset,
        validation_dataset,
        test_dataset,
        optimizer=optimizer,
        batch_size=BATCH_SIZE,
        save_dir=tmpdir_factory.mktemp("model_save"),
        num_epochs=NUM_EPOCHS,
        save_every=1,
        log_train_metrics_every=10,
        metric=metric,
        track_for_best=request.param,
    )

    options = {
        "MAX_NUM_WORDS": MAX_NUM_WORDS,
        "MAX_LENGTH": MAX_LENGTH,
        "BATCH_SIZE": BATCH_SIZE,
        "NUM_TOKENS": NUM_TOKENS,
        "EMB_DIM": EMB_DIM,
        "VOCAB_SIZE": VOCAB_SIZE,
        "NUM_CLASSES": NUM_CLASSES,
        "NUM_EPOCHS": NUM_EPOCHS,
    }

    return engine, tokens, labels, options
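
A consuming test receives (engine, tokens, labels, options) and typically calls engine.run() to drive a full train/validate/test cycle, as Code Example #3 does. Since this fixture also reads request.param for track_for_best, tests must parametrize it indirectly, following the same pattern sketched after Code Example #3.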