Beispiel #1
0
 def test_dropout_WITH_PROB_ZERO(self):
     # With dropout_p=0 two forward passes over the same input must produce
     # exactly the same output tensor.
     encoder = EncoderRNN(self.vocab_size, 10, 50, 16, dropout_p=0)

     # Overwrite every parameter in place with values drawn from U(-1, 1).
     for weight in encoder.parameters():
         weight.data.uniform_(-1, 1)

     first_pass, _hidden = encoder(self.input_var, self.lengths)
     second_pass, _hidden = encoder(self.input_var, self.lengths)
     outputs_match = torch.equal(first_pass.data, second_pass.data)
     self.assertTrue(outputs_match)
Beispiel #2
0
    def test_input_dropout_WITH_NON_ZERO_PROB(self):
        # With input_dropout_p=0.5 repeated forward passes over the same input
        # are expected to disagree at least once.
        encoder = EncoderRNN(self.vocab_size, 10, 50, 16, input_dropout_p=0.5)

        # Overwrite every parameter in place with values drawn from U(-1, 1).
        for weight in encoder.parameters():
            weight.data.uniform_(-1, 1)

        # Compare up to 50 pairs of passes and stop at the first pair that
        # differs; the else branch runs only if every pair was identical.
        for _attempt in range(50):
            first_pass, _hidden = encoder(self.input_var, self.lengths)
            second_pass, _hidden = encoder(self.input_var, self.lengths)
            if not torch.equal(first_pass.data, second_pass.data):
                always_equal = False
                break
        else:
            always_equal = True
        self.assertFalse(always_equal)
Beispiel #3
0
    def setUp(self):
        # Fixture: load the tab-separated eng-fra pairs that live next to this
        # test file, build both vocabularies, create a bucket iterator over
        # the dataset and assemble an LSTM seq2seq model.
        here = os.path.dirname(os.path.realpath(__file__))
        corpus_path = os.path.join(here, 'data/eng-fra.txt')

        src = SourceField(batch_first=True)
        tgt = TargetField(batch_first=True)
        self.dataset = torchtext.data.TabularDataset(
            path=corpus_path,
            format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        # Source vocabulary first, then target, both from the same dataset.
        for field in (src, tgt):
            field.build_vocab(self.dataset)

        # Batches of 64; examples are sorted by source length within a batch
        # only (sort=False, sort_within_batch=True).
        iterator_options = dict(
            dataset=self.dataset,
            batch_size=64,
            sort=False,
            sort_within_batch=True,
            sort_key=lambda example: len(example.src),
            repeat=False,
        )
        self.data_iterator = torchtext.data.BucketIterator(**iterator_options)

        cell_type = 'lstm'
        encoder = EncoderRNN(len(src.vocab), 10, 10, 10, rnn_cell=cell_type)
        decoder = DecoderRNN(
            len(tgt.vocab), 10, 10, tgt.sos_id, tgt.eos_id, rnn_cell=cell_type
        )
        self.seq2seq = Seq2seq(encoder, decoder)

        # Overwrite every model parameter in place with values from
        # U(-0.08, 0.08).
        for weight in self.seq2seq.parameters():
            weight.data.uniform_(-0.08, 0.08)
Beispiel #4
0
    def test_dropout_WITH_NON_ZERO_PROB(self):
        # n_layers=2 is essential here: according to PyTorch's documentation,
        # dropout does not work when the RNN has only a single layer.
        encoder = EncoderRNN(
            self.vocab_size, 10, 50, 16, n_layers=2, dropout_p=0.5
        )

        # Overwrite every parameter in place with values drawn from U(-1, 1).
        for weight in encoder.parameters():
            weight.data.uniform_(-1, 1)

        # Compare up to 50 pairs of passes and stop at the first pair that
        # differs; the else branch runs only if every pair was identical.
        for _attempt in range(50):
            first_pass, _hidden = encoder(self.input_var, self.lengths)
            second_pass, _hidden = encoder(self.input_var, self.lengths)
            if not torch.equal(first_pass.data, second_pass.data):
                always_equal = False
                break
        else:
            always_equal = True
        self.assertFalse(always_equal)
Beispiel #5
0
    def setUp(self):
        # Fixture: load the tab-separated eng-fra pairs that live next to this
        # test file, build both vocabularies and assemble an LSTM seq2seq
        # model.
        here = os.path.dirname(os.path.realpath(__file__))
        corpus_path = os.path.join(here, 'data/eng-fra.txt')

        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=corpus_path,
            format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        # Source vocabulary first, then target, both from the same dataset.
        for field in (src, tgt):
            field.build_vocab(self.dataset)

        cell_type = 'lstm'
        encoder = EncoderRNN(len(src.vocab), 10, 10, 10, rnn_cell=cell_type)
        decoder = DecoderRNN(
            len(tgt.vocab), 10, 10, tgt.sos_id, tgt.eos_id, rnn_cell=cell_type
        )
        self.seq2seq = Seq2seq(encoder, decoder)

        # Overwrite every model parameter in place with values from
        # U(-0.08, 0.08).
        for weight in self.seq2seq.parameters():
            weight.data.uniform_(-0.08, 0.08)
Beispiel #6
0
    def setUpClass(self):
        # Fixture: build a Predictor around an LSTM seq2seq model whose
        # vocabularies come from the eng-fra pairs next to this test file.
        # NOTE(review): no @classmethod decorator is visible above this
        # method in this excerpt; confirm it is present in the real file.
        here = os.path.dirname(os.path.realpath(__file__))
        corpus_path = os.path.join(here, 'data/eng-fra.txt')

        src = SourceField(batch_first=True)
        trg = TargetField(batch_first=True)
        corpus = torchtext.data.TabularDataset(
            path=corpus_path,
            format='tsv',
            fields=[('src', src), ('trg', trg)],
        )
        # Source vocabulary first, then target, both from the same dataset.
        for field in (src, trg):
            field.build_vocab(corpus)

        cell_type = 'lstm'
        encoder = EncoderRNN(len(src.vocab), 5, 10, 10, rnn_cell=cell_type)
        decoder = DecoderRNN(
            len(trg.vocab), 10, 10, trg.sos_id, trg.eos_id, rnn_cell=cell_type
        )
        model = Seq2seq(encoder, decoder)
        self.predictor = Predictor(model, src.vocab, trg.vocab)