# Assumed imports (standard AllenNLP paths for the classes used below):
import torch
from allennlp.modules.attention import DotProductAttention
from allennlp.modules.seq2seq_decoders import LstmCellDecoderNet


def test_lstm_cell_decoder_net_forward_without_bidirectionality(self):
        decoder_inout_dim = 10
        lstm_decoder_net = LstmCellDecoderNet(
                decoding_dim=decoder_inout_dim,
                target_embedding_dim=decoder_inout_dim,
                attention=DotProductAttention(),
                bidirectional_input=False)
        batch_size = 5
        time_steps = 10
        encoded_state = torch.rand(batch_size, time_steps, decoder_inout_dim)
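        # Mask out the padded timesteps: sequence 0 has length 7, sequence 1 has length 5.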
        source_mask = torch.ones(batch_size, time_steps)
        source_mask[0, 7:] = 0
        source_mask[1, 5:] = 0
        encoder_out = {
                "source_mask": source_mask,
                "encoder_outputs": encoded_state
        }
        prev_step_prediction_embedded = torch.rand(batch_size, 1, decoder_inout_dim)
        prev_state = lstm_decoder_net.init_decoder_state(encoder_out)

        next_state, decoded_vec = lstm_decoder_net(prev_state, encoded_state,
                                                   source_mask, prev_step_prediction_embedded)
        assert list(next_state["decoder_hidden"].shape) == [batch_size, decoder_inout_dim]
        assert list(next_state["decoder_context"].shape) == [batch_size, decoder_inout_dim]
        assert list(decoded_vec.shape) == [batch_size, decoder_inout_dim]
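
The two pieces exercised above, init_decoder_state and the single-step forward call, are all that is needed to drive a greedy decoding loop by hand. A minimal sketch under the same setup as the test; output_projection and target_embedder are hypothetical stand-ins for the layers that AutoRegressiveSeqDecoder normally supplies:

# Greedy decoding loop (sketch only; output_projection and target_embedder
# are hypothetical stand-ins, not part of LstmCellDecoderNet).
state = lstm_decoder_net.init_decoder_state(encoder_out)
prev_embedded = torch.rand(batch_size, 1, decoder_inout_dim)  # start-token embedding
for _ in range(10):
    state, decoded_vec = lstm_decoder_net(state, encoded_state, source_mask, prev_embedded)
    logits = output_projection(decoded_vec)       # (batch_size, vocab_size)
    predicted = logits.argmax(dim=-1)             # greedy token choice per batch element
    prev_embedded = target_embedder(predicted).unsqueeze(1)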
Example #2
def test_lstm_cell_decoder_net_init(self):
    decoder_inout_dim = 10
    lstm_decoder_net = LstmCellDecoderNet(
        decoding_dim=decoder_inout_dim,
        target_embedding_dim=decoder_inout_dim,
        attention=DotProductAttention(),
        bidirectional_input=False,
    )
    batch_size = 5
    time_steps = 10
    encoded_state = torch.rand(batch_size, time_steps, decoder_inout_dim)
    source_mask = torch.ones(batch_size, time_steps)
    source_mask[0, 7:] = 0
    source_mask[1, 5:] = 0
    encoder_out = {"source_mask": source_mask, "encoder_outputs": encoded_state}
    decoder_init_state = lstm_decoder_net.init_decoder_state(encoder_out)
    assert list(decoder_init_state["decoder_hidden"].shape) == [batch_size, decoder_inout_dim]
    assert list(decoder_init_state["decoder_context"].shape) == [batch_size, decoder_inout_dim]
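
For the bidirectional_input=True case, init_decoder_state assembles the initial hidden state from the final forward half and the first-timestep backward half of encoder_outputs, so the encoder output dimension must still equal decoding_dim. A hedged sketch reusing the tensors from the test above:

# Sketch (assumption: same tensors as the test above; decoder_inout_dim is
# even, so the forward/backward halves split cleanly).
lstm_decoder_net_bi = LstmCellDecoderNet(
    decoding_dim=decoder_inout_dim,
    target_embedding_dim=decoder_inout_dim,
    attention=DotProductAttention(),
    bidirectional_input=True,
)
bi_init_state = lstm_decoder_net_bi.init_decoder_state(encoder_out)
assert list(bi_init_state["decoder_hidden"].shape) == [batch_size, decoder_inout_dim]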
Example #3
def test_model(self):
    self.setUp()
    embedding = Embedding(
        num_embeddings=self.vocab.get_vocab_size('tokens'),
        embedding_dim=EMBEDDING_DIM)
    embedder = BasicTextFieldEmbedder({'tokens': embedding})
    encoder = PytorchSeq2SeqWrapper(
        DenoisingEncoder(bidirectional=True,
                         num_layers=2,
                         input_size=EMBEDDING_DIM,
                         hidden_size=HIDDEN_DIM,
                         use_bridge=True))
    decoder_net = LstmCellDecoderNet(decoding_dim=HIDDEN_DIM,
                                     target_embedding_dim=EMBEDDING_DIM)
    decoder = AutoRegressiveSeqDecoder(max_decoding_steps=100,
                                       target_namespace='tokens',
                                       target_embedder=embedding,
                                       beam_size=5,
                                       decoder_net=decoder_net,
                                       vocab=self.vocab)
    model = SalienceSeq2Seq(encoder=encoder,
                            decoder=decoder,
                            vocab=self.vocab,
                            source_text_embedder=embedder)
    optimizer = optim.Adam(model.parameters(), lr=0.1)
    iterator = BucketIterator(batch_size=4,
                              sorting_keys=[("source_tokens", "num_tokens")])
    iterator.index_with(self.vocab)
    if torch.cuda.is_available():
        cuda_device = 0
        model = model.cuda(cuda_device)
    else:
        cuda_device = -1
    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      train_dataset=self.train_dataset,
                      validation_dataset=self.val_dataset,
                      iterator=iterator,
                      num_epochs=2,
                      cuda_device=cuda_device)
    trainer.train()
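The remainder of this example is module-level training code from the same file: it loads a vocabulary from disk if one exists (otherwise building and saving it), then wires a plain LSTM encoder to the same LstmCellDecoderNet/AutoRegressiveSeqDecoder stack and trains with BucketIterator and Trainer.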
if os.path.exists(vocab_path):
    vocab = Vocabulary.from_files(vocab_path)
else:
    vocab = Vocabulary.from_instances(train_dataset, max_vocab_size=80000)
    vocab.save_to_files(vocab_path)
embedding = Embedding(num_embeddings=vocab.get_vocab_size('train'),
                      vocab_namespace='train',
                      embedding_dim=128,
                      trainable=True)
embedder = BasicTextFieldEmbedder({'tokens': embedding})
encoder = PytorchSeq2SeqWrapper(
    torch.nn.LSTM(input_size=128,
                  hidden_size=128,
                  num_layers=1,
                  batch_first=True))
decoder_net = LstmCellDecoderNet(decoding_dim=128,
                                 target_embedding_dim=128)
decoder = AutoRegressiveSeqDecoder(max_decoding_steps=100,
                                   target_namespace='train',
                                   target_embedder=embedding,
                                   beam_size=5,
                                   decoder_net=decoder_net,
                                   vocab=vocab)
model = Seq2SeqModel(encoder=encoder,
                     decoder=decoder,
                     vocab=vocab,
                     src_embedder=embedder)
optimizer = optim.SGD(model.parameters(), lr=0.1)
iterator = BucketIterator(batch_size=8,
                          sorting_keys=[("source_tokens", "num_tokens")])
iterator.index_with(vocab)
if torch.cuda.is_available():