Example #1
def train_model(
    train_source,
    train_target,
    dev_source,
    dev_target,
    experiment_directory,
    resume=False,
):
    # Prepare dataset; build_vocab caps the source and target vocabulary sizes
    train = Seq2SeqDataset.from_file(train_source, train_target)
    train.build_vocab(300, 6000)
    dev = Seq2SeqDataset.from_file(
        dev_source,
        dev_target,
        share_fields_from=train,
    )
    input_vocab = train.src_field.vocab
    output_vocab = train.tgt_field.vocab

    # Prepare loss
    weight = torch.ones(len(output_vocab))
    pad = output_vocab.stoi[train.tgt_field.pad_token]
    loss = Perplexity(weight, pad)
    if torch.cuda.is_available():
        loss.cuda()

    # When resuming, the trainer restores the model and optimizer from the
    # latest checkpoint, so they are only built here for a fresh run.
    seq2seq = None
    optimizer = None
    if not resume:
        seq2seq, optimizer, scheduler = initialize_model(
            train, input_vocab, output_vocab)

    # Train
    trainer = SupervisedTrainer(
        loss=loss,
        batch_size=32,
        checkpoint_every=50,
        print_every=10,
        experiment_directory=experiment_directory,
    )
    start = time.perf_counter()
    try:
        seq2seq = trainer.train(
            seq2seq,
            train,
            n_epochs=10,
            dev_data=dev,
            optimizer=optimizer,
            teacher_forcing_ratio=0.5,
            resume=resume,
        )
    # Capture ^C
    except KeyboardInterrupt:
        pass
    elapsed = time.perf_counter() - start
    logging.info('Training time: %.2fs', elapsed)

    return seq2seq, input_vocab, output_vocab
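
A minimal invocation sketch for train_model; every path and directory name below is a placeholder, not a value from the source.

# Hypothetical driver; file paths and the experiment directory are placeholders.
import logging

logging.basicConfig(level=logging.INFO)
model, src_vocab, tgt_vocab = train_model(
    train_source='data/train.src',
    train_target='data/train.tgt',
    dev_source='data/dev.src',
    dev_target='data/dev.tgt',
    experiment_directory='./experiment',
    resume=False,
)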
Example #2
    def predict(self, src_seq):
        """ Make prediction given `src_seq` as input.

        Args:
            src_seq (list): list of input tokens in source language

        Returns:
            tgt_seq (list): list of output tokens in target language as predicted
            by the pre-trained model
        """
        with torch.no_grad():
            # Map source tokens to vocabulary ids and add a batch dimension.
            src_id_seq = Variable(
                torch.LongTensor([self.src_vocab.stoi[tok]
                                  for tok in src_seq])).view(1, -1)
            if torch.cuda.is_available():
                src_id_seq = src_id_seq.cuda()

            # Wrap the id sequence in a single-example batch; `tgt` is None
            # because the target sequence is what we are predicting.
            dataset = Seq2SeqDataset.from_list(' '.join(src_seq))
            dataset.vocab = self.src_vocab
            batch = torchtext.data.Batch.fromvars(dataset,
                                                  1,
                                                  src=(src_id_seq,
                                                       [len(src_seq)]),
                                                  tgt=None)

            _, _, other = self.model(batch)

        # Read off the predicted length and decode the output ids back into
        # target-language tokens.
        length = other['length'][0]
        tgt_id_seq = [other['sequence'][di][0].data[0] for di in range(length)]
        tgt_seq = [self.tgt_vocab.itos[tok] for tok in tgt_id_seq]
        return tgt_seq
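
Assuming the enclosing class is a predictor that stores self.model, self.src_vocab and self.tgt_vocab (the constructor below is an assumption, not a confirmed API), usage looks like:

# Hypothetical usage; the Predictor constructor signature is an assumption.
predictor = Predictor(model, src_vocab, tgt_vocab)
tokens = predictor.predict(['how', 'are', 'you', '?'])
print(' '.join(tokens))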
Example #3
    def test_init_FROM_LIST(self):
        src_list = [['1', '2', '3'], ['4', '5', '6', '7']]
        dataset = Seq2SeqDataset.from_list(src_list, dynamic=False)

        self.assertEqual(len(dataset), 2)

        with open('temp', 'w') as tmp_file:
            for seq in src_list:
                tmp_file.write(' '.join(seq) + "\n")
        from_file = Seq2SeqDataset.from_file('temp', dynamic=False)

        self.assertEqual(len(dataset.examples), len(from_file.examples))
        for l, f in zip(dataset.examples, from_file.examples):
            self.assertEqual(l.src, f.src)
        os.remove('temp')
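
A variant of the temp-file handling above that avoids a fixed filename, using the standard tempfile module; the logic otherwise mirrors the test.

import os
import tempfile

# Write the sequences to a uniquely named file instead of a fixed 'temp'.
with tempfile.NamedTemporaryFile('w', delete=False, suffix='.txt') as f:
    for seq in src_list:
        f.write(' '.join(seq) + '\n')
    path = f.name
try:
    from_file = Seq2SeqDataset.from_file(path, dynamic=False)
finally:
    os.remove(path)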
Example #4
    def test_indices(self):
        dataset = Seq2SeqDataset.from_file(self.src_path, self.tgt_path, dynamic=False)
        dataset.build_vocab(1000, 1000)
        batch_size = 25

        # device=-1 selects the CPU in the legacy torchtext API.
        generator = torchtext.data.BucketIterator(dataset, batch_size, device=-1)
        batch = next(iter(generator))
        self.assertTrue(hasattr(batch, 'index'))
Example #5
 def test_init_SRC_AND_TGT(self):
     dataset = Seq2SeqDataset.from_file(self.src_path,
                                        self.tgt_path,
                                        dynamic=False)
     self.assertEqual(len(dataset.fields), 3)
     self.assertEqual(len(dataset), 100)
     ex = dataset.examples[0]
     self.assertGreater(len(getattr(ex, seq2seq.tgt_field_name)), 2)
Example #6
 def test_dynamic(self):
     dataset = Seq2SeqDataset.from_file(self.src_path, self.tgt_path, dynamic=True)
     self.assertTrue('src_index' in dataset.fields)
     for i, ex in enumerate(dataset.examples):
         idx = ex.index
         self.assertEqual(i, idx)
         src_vocab = dataset.dynamic_vocab[i]
         for tok, tok_id in zip(ex.src, ex.src_index):
             self.assertEqual(src_vocab.stoi[tok], tok_id)
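
For context, a hypothetical illustration (not library code) of what the per-example dynamic vocabulary checked above implies: each source sequence gets its own contiguous id space, so ex.src_index holds each token's id in that example-local vocab.

# Hypothetical sketch of a per-example dynamic vocabulary.
def build_dynamic_stoi(tokens):
    stoi = {}
    for tok in tokens:
        if tok not in stoi:
            stoi[tok] = len(stoi)  # first occurrence gets the next id
    return stoi

src = ['the', 'cat', 'sat', 'on', 'the', 'mat']
stoi = build_dynamic_stoi(src)
print([stoi[tok] for tok in src])  # [0, 1, 2, 3, 0, 4]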
Example #7
 @classmethod
 def setUpClass(cls):
     num_class = 5
     cls.num_class = num_class
     batch_size = 5
     length = 7
     cls.outputs = [
         F.softmax(Variable(torch.randn(batch_size, num_class)), dim=-1)
         for _ in range(length)
     ]
     targets = [
         random.randint(0, num_class - 1)
         for _ in range(batch_size * (length + 1))
     ]
     targets_list = [str(x) for x in targets]
     sources = ['0'] * len(targets)
     dataset = Seq2SeqDataset.from_list(sources, targets_list)
     dataset.build_vocab(5, 5)
     cls.targets = Variable(torch.LongTensor(targets)).view(
         batch_size, length + 1)
     cls.batch = torchtext.data.Batch.fromvars(dataset,
                                               batch_size,
                                               tgt=cls.targets)
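
A sketch of a test method that could consume these fixtures, using the eval_batch/get_loss pattern this library's loss classes expose elsewhere; the pad index 0 and the exact call pattern are assumptions.

 # Hypothetical test method for the same TestCase; API usage is assumed.
 def test_perplexity_is_positive(self):
     weight = torch.ones(self.num_class)
     loss = Perplexity(weight, 0)  # 0 as a stand-in pad index (assumption)
     for step, output in enumerate(self.outputs):
         # Score the step-t outputs against the (t+1)-th target column.
         loss.eval_batch(output, self.batch.tgt[:, step + 1])
     self.assertTrue(loss.get_loss() > 0)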
Example #8
parser.add_argument('--load_checkpoint', action='store', dest='load_checkpoint',
                    help='The name of the checkpoint to load, usually an encoded time string')
parser.add_argument('--resume', action='store_true', dest='resume',
                    default=False,
                    help='Indicates if training has to be resumed from the latest checkpoint')
parser.add_argument('--log-level', dest='log_level',
                    default='info',
                    help='Logging level.')

opt = parser.parse_args()

LOG_FORMAT = '%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=getattr(logging, opt.log_level.upper()))
logging.info(opt)

# Prepare dataset
train = Seq2SeqDataset.from_file(opt.train_src, opt.train_tgt)
train.build_vocab(50000, 50000)
dev = Seq2SeqDataset.from_file(opt.dev_src, opt.dev_tgt, share_fields_from=train)
input_vocab = train.src_field.vocab
output_vocab = train.tgt_field.vocab

# Prepare loss
weight = torch.ones(len(output_vocab))
pad = output_vocab.stoi[train.tgt_field.pad_token]
loss = Perplexity(weight, pad)
if torch.cuda.is_available():
    loss.cuda()

if opt.load_checkpoint is not None:
    checkpoint_path = os.path.join(opt.expt_dir, Checkpoint.CHECKPOINT_DIR_NAME,
                                   opt.load_checkpoint)
    logging.info("loading checkpoint from %s", checkpoint_path)
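    # Assumed continuation (the snippet is truncated here): Checkpoint.load
    # restores the serialized model and vocabularies; the attribute names
    # follow this library's Checkpoint usage and should be verified.
    checkpoint = Checkpoint.load(checkpoint_path)
    seq2seq = checkpoint.model
    input_vocab = checkpoint.input_vocab
    output_vocab = checkpoint.output_vocab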
Example #9
 def test_init_ONLY_SRC(self):
     dataset = Seq2SeqDataset.from_file(self.src_path, dynamic=False)
     self.assertEqual(len(dataset.fields), 2)
     self.assertEqual(len(dataset), 100)
     self.assertTrue(hasattr(dataset.examples[0], seq2seq.src_field_name))
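
These test methods reference self.src_path and self.tgt_path, which must come from a setUp (or setUpClass) not shown here; a minimal sketch with placeholder file names:

 # Hypothetical fixture for the tests above; paths are placeholders.
 def setUp(self):
     self.src_path = os.path.join(os.path.dirname(__file__), 'data/src.txt')
     self.tgt_path = os.path.join(os.path.dirname(__file__), 'data/tgt.txt')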