Example #1
import torchtext

# SourceField and TargetField are assumed to be imported from the
# surrounding seq2seq codebase (its dataset/fields module).
def get_train_valid_tests(train_path,
                          valid_path,
                          test_paths,
                          max_len,
                          src_vocab=50000,
                          tgt_vocab=50000):
    """Gets the formatted train, valid, and test data."""
    def len_filter(example):
        return len(example.src) <= max_len and len(example.tgt) <= max_len

    src = SourceField()
    tgt = TargetField(include_eos=True)
    fields = [('src', src), ('tgt', tgt)]
    train = torchtext.data.TabularDataset(path=train_path,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter)
    valid = torchtext.data.TabularDataset(path=valid_path,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter)

    tests = []
    for t in test_paths:
        tests.append(
            torchtext.data.TabularDataset(path=t,
                                          format='tsv',
                                          fields=fields,
                                          filter_pred=len_filter))

    src.build_vocab(train, max_size=src_vocab)
    tgt.build_vocab(train, max_size=tgt_vocab)

    return train, valid, tests, src, tgt
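A minimal sketch of how the function above might be called; the file paths are hypothetical, and the iterator setup mirrors the BucketIterator usage in the later examples:

train, valid, tests, src, tgt = get_train_valid_tests(
    'data/train.tsv', 'data/valid.tsv', ['data/test.tsv'], max_len=50)
train_iter = torchtext.data.BucketIterator(
    dataset=train, batch_size=32,
    sort=False, sort_within_batch=True,
    sort_key=lambda x: len(x.src), repeat=False)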
Example #2
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        tgt = TargetField(batch_first=True)
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'),
            format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        self.data_iterator = torchtext.data.BucketIterator(
            dataset=self.dataset,
            batch_size=64,
            sort=False,
            sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            repeat=False)

        encoder = EncoderRNN(len(src.vocab), 10, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(tgt.vocab),
                             10,
                             10,
                             tgt.sos_id,
                             tgt.eos_id,
                             rnn_cell='lstm')
        self.seq2seq = Seq2seq(encoder, decoder)

        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
Example #3
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)
Example #4
    def test_targetfield_specials(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        data_path = os.path.join(test_path, 'data/eng-fra.txt')
        field = TargetField()
        train = torchtext.data.TabularDataset(
            path=data_path, format='tsv',
            fields=[('src', torchtext.data.Field()), ('trg', field)]
        )
        self.assertIsNone(field.sos_id)
        self.assertIsNone(field.eos_id)
        field.build_vocab(train)
        self.assertIsNotNone(field.sos_id)
        self.assertIsNotNone(field.eos_id)
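Continuing from the test body above: once build_vocab has run, the ids can be mapped back to their tokens through the standard torchtext vocab.itos list (the exact special-token strings shown are an assumption):

sos_token = field.vocab.itos[field.sos_id]  # e.g. '<sos>'
eos_token = field.vocab.itos[field.eos_id]  # e.g. '<eos>'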
Example #5
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        tgt = TargetField(batch_first=True)
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        self.data_iterator = torchtext.data.BucketIterator(
            dataset=self.dataset, batch_size=4,
            sort=False, sort_within_batch=True,
            sort_key=lambda x: len(x.src),
            repeat=False)
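For context, one way the iterator above might be consumed; this assumes SourceField sets include_lengths, so batch.src unpacks into a token tensor plus lengths (treat that as an assumption about this codebase):

for batch in self.data_iterator:
    src_seqs, src_lengths = batch.src  # (batch, seq_len) token ids + lengths
    tgt_seqs = batch.tgt               # target token ids, batch_first
    break                              # inspect just the first batch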
Example #6
    def setUp(self):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField()
        tgt = TargetField()
        self.dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'), format='tsv',
            fields=[('src', src), ('tgt', tgt)],
        )
        src.build_vocab(self.dataset)
        tgt.build_vocab(self.dataset)

        encoder = EncoderRNN(len(src.vocab), 10, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(tgt.vocab), 10, 10, tgt.sos_id, tgt.eos_id, rnn_cell='lstm')
        self.seq2seq = Seq2seq(encoder, decoder)

        for param in self.seq2seq.parameters():
            param.data.uniform_(-0.08, 0.08)
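A smoke-test sketch for the model built above, assuming an iterator like those in the other examples and the (inputs, lengths, targets) forward signature used by this library's trainer; both are assumptions, not part of the original:

iterator = torchtext.data.BucketIterator(
    dataset=self.dataset, batch_size=2, sort=False,
    sort_within_batch=True, sort_key=lambda x: len(x.src), repeat=False)
batch = next(iter(iterator))
inputs, lengths = batch.src  # SourceField is assumed to include lengths
decoder_outputs, decoder_hidden, other = self.seq2seq(
    inputs, lengths.tolist(), batch.tgt)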
Example #7
    @classmethod
    def setUpClass(cls):
        test_path = os.path.dirname(os.path.realpath(__file__))
        src = SourceField(batch_first=True)
        trg = TargetField(batch_first=True)
        dataset = torchtext.data.TabularDataset(
            path=os.path.join(test_path, 'data/eng-fra.txt'),
            format='tsv',
            fields=[('src', src), ('trg', trg)],
        )
        src.build_vocab(dataset)
        trg.build_vocab(dataset)

        encoder = EncoderRNN(len(src.vocab), 5, 10, 10, rnn_cell='lstm')
        decoder = DecoderRNN(len(trg.vocab),
                             10,
                             10,
                             trg.sos_id,
                             trg.eos_id,
                             rnn_cell='lstm')
        seq2seq = Seq2seq(encoder, decoder)
        cls.predictor = Predictor(seq2seq, src.vocab, trg.vocab)
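A sketch of a test that could follow, assuming Predictor.predict takes a list of source tokens and returns a list of predicted target tokens (the input sentence is made up):

    def test_predict_returns_token_list(self):
        predicted = self.predictor.predict('how are you'.split())
        self.assertIsInstance(predicted, list)  # target-side tokens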
Example #8
if opt.load_checkpoint is not None:  # resume: restore model and vocabs from disk
    checkpoint_path = os.path.join(opt.output_dir, opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model

    input_vocab = checkpoint.input_vocab
    src.vocab = input_vocab

    output_vocab = checkpoint.output_vocab
    tgt.vocab = output_vocab
    tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
    tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]

else:
    # build vocabulary
    src.build_vocab(train, max_size=opt.src_vocab)
    tgt.build_vocab(train, max_size=opt.tgt_vocab)
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # Initialize model
    hidden_size = opt.hidden_size
    decoder_hidden_size = hidden_size * 2 if opt.bidirectional else hidden_size
    encoder = EncoderRNN(len(src.vocab),
                         max_len,
                         hidden_size,
                         opt.embedding_size,
                         dropout_p=opt.dropout_p_encoder,
                         n_layers=opt.n_layers,
                         bidirectional=opt.bidirectional,
                         rnn_cell=opt.rnn_cell,
                         variable_lengths=True)
Example #9
if opt.load_checkpoint is not None:  # resume: restore model and vocabs from disk
    checkpoint_path = os.path.join(opt.output_dir, opt.load_checkpoint)
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model

    input_vocab = checkpoint.input_vocab
    src.vocab = input_vocab

    output_vocab = checkpoint.output_vocab
    tgt.vocab = output_vocab
    tgt.eos_id = tgt.vocab.stoi[tgt.SYM_EOS]
    tgt.sos_id = tgt.vocab.stoi[tgt.SYM_SOS]

else:
    # build vocabulary
    src.build_vocab(train.dataset, max_size=opt.src_vocab)
    tgt.build_vocab(train.dataset, max_size=opt.tgt_vocab)
    input_vocab = src.vocab
    output_vocab = tgt.vocab

    # Initialize model
    if opt.model == 'seq2attn':
        hidden_size = opt.hidden_size
        decoder_hidden_size = hidden_size * 2 if opt.bidirectional else hidden_size
        seq2attn_encoder = EncoderRNN(
            len(src.vocab),
            max_len,
            hidden_size,
            opt.embedding_size,
            dropout_p=opt.dropout_p_encoder,
            n_layers=opt.n_layers,
            bidirectional=opt.bidirectional,