Example No. 1
    def test_read_from_file(self, lazy):
        reader = SquadReader(lazy=lazy)
        instances = ensure_list(reader.read(AllenNlpTestCase.FIXTURES_ROOT / 'data' / 'squad.json'))
        assert len(instances) == 5

        assert [t.text for t in instances[0].fields["question"].tokens[:3]] == ["To", "whom", "did"]
        assert [t.text for t in instances[0].fields["passage"].tokens[:3]] == ["Architecturally", ",", "the"]
        assert [t.text for t in instances[0].fields["passage"].tokens[-3:]] == ["of", "Mary", "."]
        assert instances[0].fields["span_start"].sequence_index == 102
        assert instances[0].fields["span_end"].sequence_index == 104

        assert [t.text for t in instances[1].fields["question"].tokens[:3]] == ["What", "sits", "on"]
        assert [t.text for t in instances[1].fields["passage"].tokens[:3]] == ["Architecturally", ",", "the"]
        assert [t.text for t in instances[1].fields["passage"].tokens[-3:]] == ["of", "Mary", "."]
        assert instances[1].fields["span_start"].sequence_index == 17
        assert instances[1].fields["span_end"].sequence_index == 23

        # We're checking this case because I changed the answer text to only have a partial
        # annotation for the last token, which happens occasionally in the training data.  We're
        # making sure we get a reasonable output in that case here.
        assert ([t.text for t in instances[3].fields["question"].tokens[:3]] ==
                ["Which", "individual", "worked"])
        assert [t.text for t in instances[3].fields["passage"].tokens[:3]] == ["In", "1882", ","]
        assert [t.text for t in instances[3].fields["passage"].tokens[-3:]] == ["Nuclear", "Astrophysics", "."]
        span_start = instances[3].fields["span_start"].sequence_index
        span_end = instances[3].fields["span_end"].sequence_index
        answer_tokens = instances[3].fields["passage"].tokens[span_start:(span_end + 1)]
        expected_answer_tokens = ["Father", "Julius", "Nieuwland"]
        assert [t.text for t in answer_tokens] == expected_answer_tokens
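For context on the lazy parametrization above: with lazy=True, read() yields instances on demand instead of returning a list, which is why the test normalizes the result with ensure_list(). A minimal standalone sketch, assuming the same five-instance squad.json fixture (path style borrowed from Example No. 3) and an AllenNLP version that exports SquadReader:

from allennlp.common.util import ensure_list
from allennlp.data.dataset_readers import SquadReader

# With lazy=True, read() returns a lazy iterable; ensure_list() materializes
# it so the instances can be counted and indexed either way.
reader = SquadReader(lazy=True)
instances = ensure_list(reader.read('tests/fixtures/data/squad.json'))
assert len(instances) == 5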
Example No. 2
    def test_forward(self):
        lr = 0.5
        batch_size = 16
        embedding_dim = 50

        squad_reader = SquadReader()
        # Read SQuAD train set (use the test set, since it's smaller)
        train_dataset = squad_reader.read(self.squad_test)
        vocab = Vocabulary.from_dataset(train_dataset)
        train_dataset.index_instances(vocab)

        # Random embeddings for test
        test_embed_matrix = torch.rand(vocab.get_vocab_size(), embedding_dim)
        test_cbow = CBOW(test_embed_matrix)
        optimizer = optim.Adadelta(filter(lambda p: p.requires_grad,
                                          test_cbow.parameters()),
                                   lr=lr)

        iterator = BucketIterator(batch_size=batch_size,
                                  sorting_keys=[("passage", "num_tokens"),
                                                ("question", "num_tokens")])
        for batch in iterator(train_dataset, num_epochs=1):
            passage = batch["passage"]["tokens"]
            question = batch["question"]["tokens"]
            span_start = batch["span_start"]
            span_end = batch["span_end"]
            output_dict = test_cbow(passage, question)
            softmax_start_logits = output_dict["softmax_start_logits"]
            softmax_end_logits = output_dict["softmax_end_logits"]
            loss = nll_loss(softmax_start_logits, span_start.view(-1))
            loss += nll_loss(softmax_end_logits, span_end.view(-1))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
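One subtlety in the loss computation above: torch.nn.functional.nll_loss expects log-probabilities, so the "softmax_start_logits" / "softmax_end_logits" outputs presumably hold log-softmax values despite their names. A minimal sketch of the identity being relied on (shapes are illustrative):

import torch
import torch.nn.functional as F

batch_size, passage_length = 16, 120  # illustrative shapes
logits = torch.randn(batch_size, passage_length)
targets = torch.randint(0, passage_length, (batch_size,))

# nll_loss over log-softmax outputs equals cross_entropy over raw logits.
assert torch.allclose(F.nll_loss(F.log_softmax(logits, dim=-1), targets),
                      F.cross_entropy(logits, targets))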
Example No. 3
    def test_read_from_file(self):
        reader = SquadReader()
        instances = reader.read('tests/fixtures/data/squad.json').instances
        assert len(instances) == 5

        assert instances[0].fields["question"].tokens[:3] == ["To", "whom", "did"]
        assert instances[0].fields["passage"].tokens[:3] == ["Architecturally", ",", "the"]
        assert instances[0].fields["passage"].tokens[-3:] == ["of", "Mary", "."]
        assert instances[0].fields["span_start"].sequence_index == 102
        assert instances[0].fields["span_end"].sequence_index == 104

        assert instances[1].fields["question"].tokens[:3] == ["What", "sits", "on"]
        assert instances[1].fields["passage"].tokens[:3] == ["Architecturally", ",", "the"]
        assert instances[1].fields["passage"].tokens[-3:] == ["of", "Mary", "."]
        assert instances[1].fields["span_start"].sequence_index == 17
        assert instances[1].fields["span_end"].sequence_index == 23

        # We're checking this case because I changed the answer text to only have a partial
        # annotation for the last token, which happens occasionally in the training data.  We're
        # making sure we get a reasonable output in that case here.
        assert instances[3].fields["question"].tokens[:3] == ["Which", "individual", "worked"]
        assert instances[3].fields["passage"].tokens[:3] == ["In", "1882", ","]
        assert instances[3].fields["passage"].tokens[-3:] == ["Nuclear", "Astrophysics", "."]
        span_start = instances[3].fields["span_start"].sequence_index
        span_end = instances[3].fields["span_end"].sequence_index
        answer_tokens = instances[3].fields["passage"].tokens[span_start:(span_end + 1)]
        expected_answer_tokens = ["Father", "Julius", "Nieuwland"]
        assert answer_tokens == expected_answer_tokens
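Note the contrast with Example No. 1: there, every comparison is wrapped in a [t.text for t in ...] projection because newer AllenNLP versions store Token objects in TextFields, while this older example compares plain strings directly. Under the newer API the first assertion would read:

assert [t.text for t in instances[0].fields["question"].tokens[:3]] == ["To", "whom", "did"]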
Example No. 4
def read_data(squad_train_path, squad_dev_path, max_passage_length,
              max_question_length, min_token_count):
    """
    Read SQuAD data, and filter by passage and question length.
    """
    squad_reader = SquadReader()
    # Read SQuAD train set
    train_dataset = squad_reader.read(squad_train_path)
    logger.info("Read {} training examples".format(len(
        train_dataset.instances)))

    # Filter out examples with passage length greater than max_passage_length
    # or question length greater than max_question_length
    logger.info("Filtering out examples in train set with passage length "
                "greater than {} or question length greater than {}".format(
                    max_passage_length, max_question_length))
    train_dataset.instances = [
        instance for instance in tqdm(train_dataset.instances)
        if len(instance.fields["passage"].tokens) <= max_passage_length
        and len(instance.fields["question"].tokens) <= max_question_length
    ]
    logger.info("{} training examples remain after filtering".format(
        len(train_dataset.instances)))

    # Make a vocabulary object from the train set
    train_vocab = Vocabulary.from_dataset(train_dataset,
                                          min_count=min_token_count)

    # Index the instances with the train vocabulary.
    # This converts string tokens to numerical indices.
    train_dataset.index_instances(train_vocab)

    # Read SQuAD validation set
    logger.info("Reading SQuAD validation set at {}".format(squad_dev_path))
    validation_dataset = squad_reader.read(squad_dev_path)
    logger.info("Read {} validation examples".format(
        len(validation_dataset.instances)))

    # Filter out examples with passage length greater than max_passage_length
    # or question length greater than max_question_length
    logger.info("Filtering out examples in validation set with passage length "
                "greater than {} or question length greater than {}".format(
                    max_passage_length, max_question_length))
    validation_dataset.instances = [
        instance for instance in tqdm(validation_dataset.instances)
        if len(instance.fields["passage"].tokens) <= max_passage_length
        and len(instance.fields["question"].tokens) <= max_question_length
    ]
    logger.info("{} validation examples remain after filtering".format(
        len(validation_dataset.instances)))

    # Index the instances with the train vocabulary.
    # This converts string tokens to numerical indices.
    validation_dataset.index_instances(train_vocab)
    return train_dataset, train_vocab, validation_dataset
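A hypothetical call to read_data, with limits mirroring the defaults of the training script in Example No. 13 below (the file paths are illustrative assumptions):

train_dataset, train_vocab, validation_dataset = read_data(
    squad_train_path="squad/train_small.json",  # hypothetical path
    squad_dev_path="squad/val_small.json",      # hypothetical path
    max_passage_length=150,
    max_question_length=15,
    min_token_count=10)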
Example No. 5
def read_squad2(dataset_path):
    from allennlp.data.dataset_readers import SquadReader
    reader = SquadReader()
    contexts, questions, answers = [], [], []
    for instance in reader.read(dataset_path):
        try:
            answers.append(instance['metadata']['answer_texts'][0])
        except (KeyError, IndexError):
            # use the string 'None' as a placeholder for unanswerable questions
            answers.append('None')
        contexts.append(" ".join(instance['metadata']['passage_tokens']))
        questions.append(" ".join(instance['metadata']['question_tokens']))
    return [contexts, answers, questions]
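A hypothetical call (the dataset path is an illustrative assumption); note that the return order is contexts, answers, questions rather than the order in which the lists are declared:

contexts, answers, questions = read_squad2('data/dev-v2.0.json')  # hypothetical path
print(questions[0], '->', answers[0])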
Example No. 6
    def test_read_from_file(self):
        reader = SquadReader()
        instances = reader.read('tests/fixtures/data/squad.json').instances
        assert len(instances) == 5
        assert instances[0].fields()["question"].tokens()[:3] == ["To", "whom", "did"]
        assert instances[0].fields()["passage"].tokens()[:3] == ["Architecturally", ",", "the"]
        assert instances[0].fields()["passage"].tokens()[-3:] == ["Mary", ".", "@@STOP@@"]
        assert instances[0].fields()["span_start"].sequence_index() == 102
        assert instances[0].fields()["span_end"].sequence_index() == 105
        assert instances[1].fields()["question"].tokens()[:3] == ["What", "sits", "on"]
        assert instances[1].fields()["passage"].tokens()[:3] == ["Architecturally", ",", "the"]
        assert instances[1].fields()["passage"].tokens()[-3:] == ["Mary", ".", "@@STOP@@"]
        assert instances[1].fields()["span_start"].sequence_index() == 17
        assert instances[1].fields()["span_end"].sequence_index() == 24
Example No. 7
    def test_forward(self):
        lr = 0.5
        batch_size = 16
        embedding_dim = 50
        hidden_size = 15
        dropout = 0.2

        squad_reader = SquadReader()
        # Read SQuAD train set (use the test set, since it's smaller)
        train_dataset = squad_reader.read(self.squad_test)
        vocab = Vocabulary.from_dataset(train_dataset)
        train_dataset.index_instances(vocab)

        # Random embeddings for test
        test_embed_matrix = torch.rand(vocab.get_vocab_size(), embedding_dim)
        test_attention_rnn = AttentionRNN(test_embed_matrix, hidden_size,
                                          dropout)
        try:
            optimizer = optim.Adadelta(filter(lambda p: p.requires_grad,
                                              test_attention_rnn.parameters()),
                                       lr=lr)
        except ValueError:
            # Likely there are no parameters to optimize, because
            # the code is not complete.
            pass
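        # Note: if the ValueError above fired, `optimizer` was never bound, so
        # the optimizer.step() below would raise NameError; the scaffold
        # assumes the model exposes trainable parameters once it is complete.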

        iterator = BucketIterator(batch_size=batch_size,
                                  sorting_keys=[("passage", "num_tokens"),
                                                ("question", "num_tokens")])
        for batch in iterator(train_dataset, num_epochs=1):
            passage = batch["passage"]["tokens"]
            question = batch["question"]["tokens"]
            span_start = batch["span_start"]
            span_end = batch["span_end"]
            try:
                output_dict = test_attention_rnn(passage, question)
                softmax_start_logits = output_dict["softmax_start_logits"]
                softmax_end_logits = output_dict["softmax_end_logits"]
                loss = nll_loss(softmax_start_logits, span_start.view(-1))
                loss += nll_loss(softmax_end_logits, span_end.view(-1))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            except NotImplementedError:
                # AttentionRNN.forward() not implemented yet, don't fail tests.
                pass
Example No. 8
def run_demo(model, train_vocab, host, port, cuda):
    """
    Run the web demo application.
    """
    app = Flask(__name__)
    squad_reader = SquadReader()

    @app.route("/")
    def index():
        return render_template("index.html")

    @app.route("/_get_answer")
    def get_answer():
        # Take user input and convert to Instance
        user_context = request.args.get("context", "", type=str)
        user_question = request.args.get("question", "", type=str)
        input_instance = squad_reader.text_to_instance(
            question_text=user_question, passage_text=user_context)
        # Make a dataset from the instance
        dataset = Batch([input_instance])
        dataset.index_instances(train_vocab)
        batch = dataset.as_tensor_dict()
        batch = move_to_device(batch, cuda_device=0 if cuda else -1)
        # Extract relevant data from batch.
        passage = batch["passage"]["tokens"]
        question = batch["question"]["tokens"]
        metadata = batch.get("metadata", {})

        # Run data through model to get start and end logits.
        output_dict = model(passage, question)
        start_logits = output_dict["start_logits"]
        end_logits = output_dict["end_logits"]

        # Compute the best span
        best_span = get_best_span(start_logits, end_logits)

        # Get the string corresponding to the best span
        passage_str = metadata[0]['original_passage']
        offsets = metadata[0]['token_offsets']
        predicted_span = tuple(best_span[0].data.cpu().numpy())
        start_offset = offsets[predicted_span[0]][0]
        end_offset = offsets[predicted_span[1]][1]
        best_span_string = passage_str[start_offset:end_offset]

        # Return the best string back to the GUI
        return jsonify(answer=best_span_string)

    logger.info("Launching Demo...")
    app.run(port=port, host=host)
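A client-side sketch for exercising the /_get_answer route above (assumes the demo is running locally on port 5000; requests is an illustrative choice, not a dependency of the demo itself):

import requests

resp = requests.get('http://localhost:5000/_get_answer',
                    params={'context': 'The cat sat on the mat.',
                            'question': 'Where did the cat sit?'})
print(resp.json()['answer'])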
Example No. 9
    def test_length_limit_works(self):
        # We're making sure the length of the text is correct if length limit is provided.
        reader = SquadReader(passage_length_limit=30,
                             question_length_limit=10,
                             skip_invalid_examples=True)
        instances = ensure_list(
            reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" /
                        "squad.json"))
        assert len(instances[0].fields["question"].tokens) == 10
        assert len(instances[0].fields["passage"].tokens) == 30
        # invalid examples where all the answers exceed the passage length should be skipped.
        assert len(instances) == 3

        # Length limit still works if we do not skip the invalid examples
        reader = SquadReader(passage_length_limit=30,
                             question_length_limit=10,
                             skip_invalid_examples=False)
        instances = ensure_list(
            reader.read(AllenNlpTestCase.FIXTURES_ROOT / "data" /
                        "squad.json"))
        assert len(instances[0].fields["question"].tokens) == 10
        assert len(instances[0].fields["passage"].tokens) == 30
        # invalid examples should not be skipped.
        assert len(instances) == 5

        # Make sure the answer texts do not change, so that the evaluation will not be affected.
        reader_unlimited = SquadReader(passage_length_limit=30,
                                       question_length_limit=10,
                                       skip_invalid_examples=False)
        instances_unlimited = ensure_list(
            reader_unlimited.read(AllenNlpTestCase.FIXTURES_ROOT / "data" /
                                  "squad.json"))
        for instance_x, instance_y in zip(instances, instances_unlimited):
            print(instance_x.fields["metadata"]["answer_texts"])
            assert set(instance_x.fields["metadata"]["answer_texts"]) == set(
                instance_y.fields["metadata"]["answer_texts"])
Example No. 10
    def test_can_build_from_params(self):
        reader = SquadReader.from_params(Params({}))
        # pylint: disable=protected-access
        assert reader._tokenizer.__class__.__name__ == 'WordTokenizer'
        assert reader._token_indexers[
            "tokens"].__class__.__name__ == 'SingleIdTokenIndexer'
Example No. 11
    def test_can_build_from_params(self):
        reader = SquadReader.from_params(Params({}))

        assert reader._tokenizer.__class__.__name__ == "WordTokenizer"
        assert reader._token_indexers[
            "tokens"].__class__.__name__ == "SingleIdTokenIndexer"
Example No. 12
    spacy_en = spacy.load(
        'en_core_web_sm',
        disable=['vectors', 'textcat', 'tagger', 'parser', 'ner'])

    def tokenizer(x):
        return [(token.idx, token.text) for token in spacy_en(x)
                if not token.is_space]

    if not os.path.exists('data_list.pkl'):
        with open('data/train-v2.0.txt') as f:
            data = [row for row in csv.reader(f, delimiter='\t')]
        data = [[tokenizer(x[0]), int(x[2]), int(x[3]), x[4]] for x in data]
        contexts, char_starts, char_ends, answers = zip(*data)
        with open('data_list.pkl', 'wb') as f:
            pickle.dump((contexts, char_starts, char_ends, answers), f)
    else:
        with open('data_list.pkl', 'rb') as f:
            contexts, char_starts, char_ends, answers = pickle.load(f)

    spans = get_spans(contexts, char_starts, char_ends, answers)

    data = SquadReader().read('/Users/smap11/Desktop/train-v2.0.json')

    for span, x in zip(spans, data):
        if 'span_start' in x.fields:
            if span[0] != x.fields['span_start'].sequence_index:
                print('start')
                print(span[0], x.fields['span_start'])
            if span[1] != x.fields['span_end'].sequence_index:
                print('end')
                print(span[1], x.fields['span_end'])
Example No. 13
def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    project_root = os.path.abspath(
        os.path.realpath(
            os.path.join(os.path.dirname(os.path.realpath(__file__)))))

    parser.add_argument("--squad-train-path",
                        type=str,
                        default=os.path.join(project_root, "squad",
                                             "train_small.json"),
                        help="Path to the SQuAD training data.")
    parser.add_argument("--squad-dev-path",
                        type=str,
                        default=os.path.join(project_root, "squad",
                                             "val_small.json"),
                        help="Path to the SQuAD dev data.")
    parser.add_argument("--squad-test-path",
                        type=str,
                        default=os.path.join(project_root, "squad",
                                             "test_small.json"),
                        help="Path to the SQuAD test data.")
    parser.add_argument("--glove-path",
                        type=str,
                        default=os.path.join(project_root, "glove",
                                             "glove.6B.50d.txt"),
                        help="Path to word vectors in GloVe format.")
    parser.add_argument("--load-path",
                        type=str,
                        help=("Path to load a saved model from and "
                              "evaluate on test data. May not be "
                              "used with --save-dir."))
    parser.add_argument("--save-dir",
                        type=str,
                        help=("Path to save model checkpoints and logs. "
                              "Required if not using --load-path. "
                              "May not be used with --load-path."))
    parser.add_argument("--model-type",
                        type=str,
                        default="cbow",
                        choices=["cbow", "rnn", "attention"],
                        help="Model type to train.")
    parser.add_argument("--min-token-count",
                        type=int,
                        default=10,
                        help=("Number of times a token must be observed "
                              "in order to include it in the vocabulary."))
    parser.add_argument("--max-passage-length",
                        type=int,
                        default=150,
                        help="Maximum number of words in the passage.")
    parser.add_argument("--max-question-length",
                        type=int,
                        default=15,
                        help="Maximum number of words in the question.")
    parser.add_argument("--batch-size",
                        type=int,
                        default=64,
                        help="Batch size to use in training and evaluation.")
    parser.add_argument("--hidden-size",
                        type=int,
                        default=256,
                        help="Hidden size to use in RNN and Attention models.")
    parser.add_argument("--num-epochs",
                        type=int,
                        default=25,
                        help="Number of epochs to train for.")
    parser.add_argument("--dropout",
                        type=float,
                        default=0.2,
                        help="Dropout proportion.")
    parser.add_argument("--lr",
                        type=float,
                        default=0.5,
                        help="The learning rate to use.")
    parser.add_argument("--log-period",
                        type=int,
                        default=50,
                        help=("Update training metrics every "
                              "log-period weight updates."))
    parser.add_argument("--validation-period",
                        type=int,
                        default=500,
                        help=("Calculate metrics on validation set every "
                              "validation-period weight updates."))
    parser.add_argument("--seed",
                        type=int,
                        default=0,
                        help="Random seed to use")
    parser.add_argument("--cuda",
                        action="store_true",
                        help="Train or evaluate with GPU.")
    parser.add_argument("--demo",
                        action="store_true",
                        help="Run the interactive web demo.")
    parser.add_argument("--host",
                        type=str,
                        default="0.0.0.0",
                        help="Host to use for web demo.")
    parser.add_argument("--port",
                        type=int,
                        default=5000,
                        help="Port to use for web demo.")
    args = parser.parse_args()

    # Set the random seed manually for reproducibility.
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            logger.warning("\033[35mGPU available but not running with "
                           "CUDA (use --cuda to turn on.)\033[0m")
        else:
            torch.cuda.manual_seed(args.seed)

    # Load a model from checkpoint and evaluate it on test data.
    if args.load_path:
        logger.info("Loading saved model from {}".format(args.load_path))

        # If evaluating with CPU, force all tensors to CPU.
        # This lets us load models trained on the GPU and evaluate with CPU.
        saved_state_dict = torch.load(
            args.load_path,
            map_location=None if args.cuda else lambda storage, loc: storage)

        # Extract the contents of the state dictionary.
        model_type = saved_state_dict["model_type"]
        model_weights = saved_state_dict["model_weights"]
        model_init_arguments = saved_state_dict["init_arguments"]
        model_global_step = saved_state_dict["global_step"]

        # Reconstruct a model of the proper type with the init arguments.
        saved_model = MODEL_TYPES[model_type.lower()](**model_init_arguments)
        # Load the weights
        saved_model.load_state_dict(model_weights)
        # Set the global step
        saved_model.global_step = model_global_step

        logger.info("Successfully loaded model!")

        # Move model to GPU if CUDA is on.
        if args.cuda:
            saved_model = saved_model.cuda()

        # Load the serialized train_vocab.
        vocab_dir = os.path.join(os.path.dirname(args.load_path),
                                 "train_vocab")
        logger.info("Loading train vocabulary from {}".format(vocab_dir))
        train_vocab = Vocabulary.from_files(vocab_dir)
        logger.info("Successfully loaded train vocabulary!")

        if args.demo:
            # Run the demo with the loaded model.
            run_demo(saved_model, train_vocab, args.host, args.port, args.cuda)
            sys.exit(0)

        # Evaluate on the SQuAD test set.
        logger.info("Reading SQuAD test set at {}".format(
            args.squad_test_path))
        squad_reader = SquadReader()
        test_dataset = squad_reader.read(args.squad_test_path)
        logger.info("Read {} test examples".format(len(
            test_dataset.instances)))
        # Filter out examples with passage length greater than
        # max_passage_length or question length greater than
        # max_question_length
        logger.info("Filtering out examples in test set with "
                    "passage length greater than {} or question "
                    "length greater than {}".format(args.max_passage_length,
                                                    args.max_question_length))
        test_dataset.instances = [
            instance for instance in tqdm(test_dataset.instances) if
            (len(instance.fields["passage"].tokens) <= args.max_passage_length)
            and (len(instance.fields["question"].tokens) <=
                 args.max_question_length)
        ]
        logger.info("{} test examples remain after filtering".format(
            len(test_dataset.instances)))
        # Index the instances with the train vocabulary.
        # This converts string tokens to numerical indices.
        test_dataset.index_instances(train_vocab)

        # Evaluate the model on the test set.
        logger.info("Evaluating model on the test set")
        (loss, span_start_accuracy, span_end_accuracy, span_accuracy, em,
         f1) = evaluate(saved_model, test_dataset, args.batch_size, args.cuda)
        # Log metrics to console.
        logger.info("Done evaluating on test set!")
        logger.info("Test Loss: {:.4f}".format(loss))
        logger.info(
            "Test Span Start Accuracy: {:.4f}".format(span_start_accuracy))
        logger.info("Test Span End Accuracy: {:.4f}".format(span_end_accuracy))
        logger.info("Test Span Accuracy: {:.4f}".format(span_accuracy))
        logger.info("Test EM: {:.4f}".format(em))
        logger.info("Test F1: {:.4f}".format(f1))
        sys.exit(0)

    if not args.save_dir:
        raise ValueError("Must provide a value for --save-dir if training.")

    try:
        if os.path.exists(args.save_dir):
            # save directory already exists, do we really want to overwrite?
            input("Save directory {} already exists. Press <Enter> "
                  "to clear, overwrite and continue , or "
                  "<Ctrl-c> to abort.".format(args.save_dir))
            shutil.rmtree(args.save_dir)
        os.makedirs(args.save_dir)
    except KeyboardInterrupt:
        print()
        sys.exit(0)

    # Write tensorboard logs to save_dir/logs.
    log_dir = os.path.join(args.save_dir, "logs")
    os.makedirs(log_dir)

    # Read the training and validation datasets, and get a vocabulary
    # from the train set.
    train_dataset, train_vocab, validation_dataset = read_data(
        args.squad_train_path, args.squad_dev_path, args.max_passage_length,
        args.max_question_length, args.min_token_count)

    # Save the train_vocab to a file.
    vocab_dir = os.path.join(args.save_dir, "train_vocab")
    logger.info("Saving train vocabulary to {}".format(vocab_dir))
    train_vocab.save_to_files(vocab_dir)

    # Read GloVe embeddings.
    embedding_matrix = load_embeddings(args.glove_path, train_vocab)

    # Create model of the correct type.
    if args.model_type == "cbow":
        logger.info("Building CBOW model")
        model = CBOW(embedding_matrix)
    if args.model_type == "rnn":
        logger.info("Building RNN model")
        model = RNN(embedding_matrix, args.hidden_size, args.dropout)
    if args.model_type == "attention":
        logger.info("Building attention RNN model")
        model = AttentionRNN(embedding_matrix, args.hidden_size, args.dropout)
    logger.info(model)

    # Move model to GPU if running with CUDA.
    if args.cuda:
        model = model.cuda()
    # Create the optimizer, and only update parameters where requires_grad=True
    optimizer = optim.Adadelta(filter(lambda p: p.requires_grad,
                                      model.parameters()),
                               lr=args.lr)
    # Train for the specified number of epochs.
    for i in tqdm(range(args.num_epochs), unit="epoch"):
        train_epoch(model, train_dataset, validation_dataset, args.batch_size,
                    optimizer, args.log_period, args.validation_period,
                    args.save_dir, log_dir, args.cuda)
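A hypothetical way to drive main() programmatically (the script name, paths, and flag values are illustrative assumptions, not requirements of the script):

import sys

sys.argv = [
    'train.py',                                      # hypothetical script name
    '--squad-train-path', 'squad/train_small.json',  # hypothetical path
    '--squad-dev-path', 'squad/val_small.json',      # hypothetical path
    '--model-type', 'attention',
    '--save-dir', 'runs/attention_run',              # required when training
]
main()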
Example No. 14
    def test_can_build_from_params(self):
        reader = SquadReader.from_params(Params({}))
        # pylint: disable=protected-access
        assert reader._tokenizer.__class__.__name__ == 'WordTokenizer'
        assert reader._token_indexers["tokens"].__class__.__name__ == 'SingleIdTokenIndexer'
Example No. 15
    def setUp(self):
        super(BidirectionalAttentionFlowTest, self).setUp()

        constants.GLOVE_PATH = 'tests/fixtures/glove.6B.100d.sample.txt.gz'
        reader_params = Params({
            'token_indexers': {
                'tokens': {
                    'type': 'single_id'
                },
                'token_characters': {
                    'type': 'characters'
                }
            }
        })
        dataset = SquadReader.from_params(reader_params).read(
            'tests/fixtures/data/squad.json')
        vocab = Vocabulary.from_dataset(dataset)
        self.vocab = vocab
        dataset.index_instances(vocab)
        self.dataset = dataset
        self.token_indexers = {
            'tokens': SingleIdTokenIndexer(),
            'token_characters': TokenCharactersIndexer()
        }

        self.model = BidirectionalAttentionFlow.from_params(
            self.vocab, Params({}))

        small_params = Params({
            'text_field_embedder': {
                'tokens': {
                    'type': 'embedding',
                    'pretrained_file': constants.GLOVE_PATH,
                    'trainable': False,
                    'projection_dim': 4
                },
                'token_characters': {
                    'type': 'character_encoding',
                    'embedding': {
                        'embedding_dim': 8
                    },
                    'encoder': {
                        'type': 'cnn',
                        'embedding_dim': 8,
                        'num_filters': 4,
                        'ngram_filter_sizes': [5]
                    }
                }
            },
            'phrase_layer': {
                'type': 'lstm',
                'bidirectional': True,
                'input_size': 8,
                'hidden_size': 4,
                'num_layers': 1,
            },
            'similarity_function': {
                'type': 'linear',
                'combination': 'x,y,x*y',
                'tensor_1_dim': 8,
                'tensor_2_dim': 8
            },
            'modeling_layer': {
                'type': 'lstm',
                'bidirectional': True,
                'input_size': 32,
                'hidden_size': 4,
                'num_layers': 1,
            },
            'span_end_encoder': {
                'type': 'lstm',
                'bidirectional': True,
                'input_size': 56,
                'hidden_size': 4,
                'num_layers': 1,
            },
        })
        self.small_model = BidirectionalAttentionFlow.from_params(
            self.vocab, small_params)