예제 #1
0
 def test_max_sequence_length(self):
     """Reading shard0 with max_sequence_length=10 yields instances whose
     last enumeration index is 11 (i.e. twelve instances)."""
     shard_path = os.path.join(self.FIXTURES, 'shards/shard0')
     reader = SimpleLanguageModelingDatasetReader(max_sequence_length=10)
     last_index = -1
     for last_index, _ in enumerate(reader.read(shard_path)):
         pass
     self.assertEqual(last_index, 11)
 def test_read_multiple_sentences(self):
     """With default settings the reader yields 100 instances from shard0
     (final enumeration index 99)."""
     shard_path = os.path.join(self.FIXTURES, 'shards/shard0')
     reader = SimpleLanguageModelingDatasetReader()
     final_index = -1
     for final_index, _ in enumerate(reader.read(shard_path)):
         pass
     self.assertEqual(final_index, 99)
 def test_max_sequence_length(self):
     """Capping sequences at 10 tokens produces twelve instances from shard0,
     so the last zero-based index seen is 11."""
     fixture_path = os.path.join(self.FIXTURES, 'shards/shard0')
     reader = SimpleLanguageModelingDatasetReader(max_sequence_length=10)
     index = -1
     for index, _ in enumerate(reader.read(fixture_path)):
         pass
     self.assertEqual(index, 11)
예제 #4
0
 def test_read_multiple_sentences(self):
     """A default reader produces 100 instances from shard0; the enumeration
     therefore ends on index 99."""
     fixture_path = os.path.join(self.FIXTURES, 'shards/shard0')
     reader = SimpleLanguageModelingDatasetReader()
     index = -1
     for index, _ in enumerate(reader.read(fixture_path)):
         pass
     self.assertEqual(index, 99)
예제 #5
0
    def test_text_to_instance(self):
        """text_to_instance tokenises the sentence into the 'source' field."""
        reader = SimpleLanguageModelingDatasetReader()
        instance = reader.text_to_instance('The only sentence.')
        source_field = cast(TextField, instance.fields["source"])
        token_texts = [token.text for token in source_field.tokens]
        self.assertEqual(token_texts, ["The", "only", "sentence", "."])
 def test_max_sequence_length(self):
     """With start/end tokens added, a max_sequence_length of 10 leaves room
     for fewer content tokens per instance: shard0 yields eight instances
     (last index 7)."""
     shard = os.path.join(self.FIXTURES, "shards/shard0")
     reader = SimpleLanguageModelingDatasetReader(
         max_sequence_length=10, start_tokens=["<S>"], end_tokens=["</S>"]
     )
     highest_index = -1
     for highest_index, _ in enumerate(reader.read(shard)):
         pass
     self.assertEqual(highest_index, 7)
 def test_read_single_sentence(self):
     """Reading a one-sentence file yields an instance equal (field by field)
     to the one built directly from that sentence's text."""
     path = os.path.join(self.FIXTURES, 'single_sentence.txt')
     reader = SimpleLanguageModelingDatasetReader()
     with open(path, 'r') as handle:
         sentence_text = handle.read().strip()
     expected = reader.text_to_instance(sentence_text)
     # Only the first yielded instance matters here.
     actual = next(iter(reader.read(path)), None)
     self.assertEqual(sorted(expected.fields.keys()),
                      sorted(actual.fields.keys()))
     for field_name in expected.fields:
         self.assertTrue(str(expected.fields[field_name]) == str(actual.fields[field_name]))
예제 #8
0
 def test_read_single_sentence(self):
     """The first instance read from a single-sentence file matches the
     instance produced by text_to_instance on the file's contents."""
     path = os.path.join(self.FIXTURES, 'single_sentence.txt')
     reader = SimpleLanguageModelingDatasetReader()
     with open(path, 'r') as source_file:
         sentence_text = source_file.read().strip()
     expected_instance = reader.text_to_instance(sentence_text)
     first_instance = None
     for first_instance in reader.read(path):
         break
     self.assertEqual(sorted(expected_instance.fields.keys()),
                      sorted(first_instance.fields.keys()))
     for name in expected_instance.fields:
         self.assertTrue(
             str(expected_instance.fields[name]) == str(first_instance.fields[name]))
def sentence_perplexitys(model: Model,
                         dataset_reader: SimpleLanguageModelingDatasetReader,
                         sentences: List[str],
                         tokeniser: Callable[[str], List[str]],
                         pre_tokenise: bool = True) -> List[float]:
    '''
    Return one perplexity score per sentence, in input order.

    :param model: Language model used to score each sentence.
    :param dataset_reader: Converts a sentence into the form suitable for
        the language model.
    :param sentences: The sentences to score.
    :param tokeniser: Tokeniser applied to each sentence; only used when
        ``pre_tokenise`` is True.
    :param pre_tokenise: If True, each sentence is tokenised with the given
        tokeniser and joined back together on whitespace. This ensures the
        model is given sentences tokenised in a similar way to the sentences
        it was trained on.
    :returns: Perplexity (``exp`` of the model's reported ``loss``) for each
        sentence.
    '''
    # Avoid handing an empty batch to the model; no sentences, no scores.
    if not sentences:
        return []

    sentence_instances = []
    for sentence in sentences:
        if pre_tokenise:
            # Re-join on whitespace so tokenisation matches training input.
            sentence = ' '.join(tokeniser(sentence))
        sentence_instances.append(dataset_reader.text_to_instance(sentence))

    results = model.forward_on_instances(sentence_instances)
    return [math.exp(result['loss']) for result in results]
    def test_text_to_instance(self):
        """Configured start/end tokens wrap the tokenised sentence."""
        reader = SimpleLanguageModelingDatasetReader(start_tokens=["<S>"], end_tokens=["</S>"])
        instance = reader.text_to_instance('The only sentence.')
        source = cast(TextField, instance.fields["source"])
        self.assertEqual([token.text for token in source.tokens],
                         ["<S>", "The", "only", "sentence", ".", "</S>"])
예제 #11
0
def train(model_dir):
    """Train the language model, reusing any vocabulary and model weights
    already saved under ``model_dir``, and saving both back afterwards."""
    # Data: read train/validation stories with a spaCy-backed tokenizer.
    sentence_reader = SimpleLanguageModelingDatasetReader(tokenizer=WordTokenizer(
        word_splitter=SpacyWordSplitter(language='en_core_web_sm')))
    train_data = sentence_reader.read(
        cached_path(
            '/mnt/DATA/ML/data/corpora/QA/CoQA/stories_only/coqa-train-v1.0_extract100.json'
        ))
    val_data = sentence_reader.read(
        cached_path(
            '/mnt/DATA/ML/data/corpora/QA/CoQA/stories_only/coqa-dev-v1.0.json'
        ))

    model_path = os.path.join(model_dir, 'model.th')
    vocab_path = os.path.join(model_dir, 'vocab')
    vocabulary = None
    # Reuse a previously saved vocabulary when the model dir already exists.
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    elif os.path.exists(vocab_path):
        logging.info('load vocab from: %s...' % vocab_path)
        vocabulary = Vocabulary.from_files(vocab_path)
    if vocabulary is None:
        # Built from the training instances only.
        vocabulary = Vocabulary.from_instances(train_data)
        # TODO: re-add!
        # vocab.extend_from_instances(validation_dataset)
        logging.info('save vocab to: %s...' % vocab_path)
        vocabulary.save_to_files(vocab_path)
    logging.info('data prepared')

    model = create_model(vocabulary)
    # Resume from an existing checkpoint when one is present.
    if os.path.exists(model_path):
        logging.info('load model wheights from: %s...' % model_path)
        with open(model_path, 'rb') as checkpoint:
            model.load_state_dict(torch.load(checkpoint))
    logging.info('model prepared')

    # Training setup.
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    batch_iterator = BasicIterator(batch_size=32)
    batch_iterator.index_with(vocabulary)
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      iterator=batch_iterator,
                      train_dataset=train_data,
                      validation_dataset=val_data,
                      patience=10,
                      num_epochs=10)
    logging.info('training prepared')

    trainer.train()

    logging.info('save model to: %s...' % model_path)
    with open(model_path, 'wb') as out_file:
        torch.save(model.state_dict(), out_file)