Example #1
File: train.py Project: wang-h/pynmt
def main():
    # Load config.
    config = Config("train", training=True)
    trace(config)
    torch.backends.cudnn.benchmark = True

    # Load train dataset.
    train_data = load_dataset(
        config.train_dataset,
        config.train_batch_size,
        config, prefix="Training:")
    
    # Load valid dataset.
    valid_data = load_dataset(
        config.valid_dataset,
        config.valid_batch_size,
        config, prefix="Validation:")

    # Build model.
    vocab = train_data.get_vocab()
    model = model_factory(config, 
                config.checkpoint, *vocab)
    if config.verbose: trace(model)

    # Start training.
    trg_vocab = train_data.trg_vocab
    padding_idx = trg_vocab.padding_idx
    trainer = Trainer(model, trg_vocab, padding_idx, config)
    start_epoch = 1
    for epoch in range(start_epoch, config.epochs + 1):
        trainer.train(epoch, config.epochs,
                      train_data, valid_data,
                      train_data.num_batches)
    dump_checkpoint(trainer.model, config.save_model)
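Every example on this page logs through pynmt's trace helper, whose definition is not included in the listing. To run the snippets stand-alone, a minimal print-style stand-in with the same call pattern (variadic arguments, used on config objects, models, and progress strings) is enough; the version below is only an assumed sketch, not the project's actual implementation.

import sys
from datetime import datetime

def trace(*args):
    # Assumed stand-in for pynmt's trace(): join the arguments and print
    # them with a timestamp to stderr, so log messages stay separate from
    # translation output on stdout.
    message = " ".join(str(arg) for arg in args)
    stamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    print("[%s] %s" % (stamp, message), file=sys.stderr)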
Example #2
def report_rouge(reference_corpus, translation_corpus):

    scores = rouge([" ".join(x) for x in translation_corpus],
                   [" ".join(x) for x in reference_corpus])

    trace("ROUGE-1:%.2f, ROUGE-2:%.2f" %
          (scores["rouge_1/f_score"] * 100, scores["rouge_2/f_score"] * 100))
Example #3
def model_factory(config, checkpoint, *vocab):
    # Make embeddings.
    src_vocab, trg_vocab, src_embeddings, trg_embeddings = \
        make_embeddings(config, *vocab)

    if config.system == "RNN":
        model = RNNModel(src_embeddings, trg_embeddings, trg_vocab.vocab_size,
                         config)

    elif config.system == "Transformer":
        model = TransformerModel(src_embeddings, trg_embeddings,
                                 trg_vocab.vocab_size, trg_vocab.padding_idx,
                                 config)
    if checkpoint:
        trace("Loading model parameters from checkpoint: %s." %
              str(checkpoint))
        cp = CheckPoint(checkpoint)
        model.load_state_dict(cp.state_dict['model'], strict=False)

    if config.training:
        model.train()
    else:
        model.eval()

    if config.use_gpu is not None:
        model.cuda()
    else:
        model.cpu()

    return model
Example #4
def update_lr(self, ppl, epoch):
    if self.start_decay_at is not None and epoch > self.start_decay_at:
        self.start_decay = True
    if self.last_ppl is not None and ppl > self.last_ppl:
        self.start_decay = True

    if self.start_decay:
        self.lr = self.lr * self.lr_decay_rate
        trace("Decaying learning rate to %g" % self.lr)
    self.last_ppl = ppl
    self.optimizer.param_groups[0]['lr'] = self.lr
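In words: once training is past start_decay_at, or whenever validation perplexity stops improving, the learning rate is multiplied by lr_decay_rate (and because start_decay stays set, it keeps shrinking on every later call). A stand-alone restatement of a single decay step, with hypothetical names and values, shows the arithmetic:

def decayed_lr(lr, ppl, last_ppl, epoch, start_decay_at, lr_decay_rate):
    # Illustration only; not part of pynmt.
    start_decay = (start_decay_at is not None and epoch > start_decay_at) or \
                  (last_ppl is not None and ppl > last_ppl)
    return lr * lr_decay_rate if start_decay else lr

# Perplexity got worse (12.3 > 11.8), so the rate is halved: 1.0 -> 0.5
print(decayed_lr(1.0, ppl=12.3, last_ppl=11.8, epoch=5,
                 start_decay_at=8, lr_decay_rate=0.5))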
     
Example #5
File: CheckPoint.py Project: wang-h/pynmt
def load(self, path):
    abspath = os.path.abspath(path)
    if os.path.isfile(abspath):
        saved = torch.load(path)
        if "model" in saved:
            params = saved['model']
        else:
            params = saved
        return params
    else:
        trace("#ERROR! checkpoint file does not exist!")
        sys.exit()
Example #6
def main():
    """main function for checkpoint ensemble."""
    config = Config("ensemble", training=True)
    trace(config)
    torch.backends.cudnn.benchmark = True

    train_data = load_dataset(config.train_dataset,
                              config.train_batch_size,
                              config,
                              prefix="Training:")

    # Build model.
    vocab = train_data.get_vocab()
    model = model_factory(config, config.checkpoint, *vocab)
    cp = CheckPoint(config.checkpoint)
    model.load_state_dict(cp.state_dict['model'])
    dump_checkpoint(model, config.save_model, ".ensemble")
Example #7
File: train.py Project: wang-h/pynmt
def load_dataset(dataset, batch_size, config, prefix):
    # Load training/validation dataset.
    train_src = os.path.join(
        config.data_path, dataset + "." + config.src_lang)
    train_trg = os.path.join(
        config.data_path, dataset + "." + config.trg_lang)
    train_data = DataBatchIterator(
        train_src, train_trg,
        share_vocab=config.share_vocab,
        training=config.training,
        shuffle=config.shuffle_data,
        batch_size=batch_size,
        max_length=config.max_seq_len,
        vocab=config.save_vocab,
        mini_batch_sort_order=config.mini_batch_sort_order)
    trace(prefix, train_data)
    return train_data
Example #8
File: translate.py Project: wang-h/pynmt
def main():
    config = Config("translate", training=False)
    if config.verbose: trace(config)
    torch.backends.cudnn.benchmark = True

    test_data = load_dataset(config.test_dataset,
                             config.test_batch_size,
                             config,
                             prefix="Translate:")

    # Build model.
    vocab = test_data.get_vocab()
    pred_file = codecs.open(config.output + ".pred.txt", 'w', 'utf-8')

    model = model_factory(config, config.checkpoint, *vocab)
    translator = BatchTranslator(model, config, test_data.src_vocab,
                                 test_data.trg_vocab)

    # Statistics
    counter = count(1)
    pred_list = []
    gold_list = []
    for batch in tqdm(iter(test_data), total=test_data.num_batches):

        batch_trans = translator.translate(batch)

        for trans in batch_trans:
            if config.verbose:
                sent_number = next(counter)
                trace(trans.pprint(sent_number))

            if config.plot_attn:
                plot_attn(trans.src, trans.preds[0], trans.attns[0].cpu())

            pred_file.write(" ".join(trans.preds[0]) + "\n")
            pred_list.append(trans.preds[0])
            gold_list.append(trans.gold)
    pred_file.close()  # make sure all predictions are flushed before scoring
    report_bleu(gold_list, pred_list)
    report_rouge(gold_list, pred_list)
Example #9
File: Trainer.py Project: wang-h/pynmt
    def train(self, current_epoch, epochs, train_data, valid_data,
              num_batches):
        """ Train next epoch.
        Args:
            train_data (BatchDataIterator): training dataset iterator
            valid_data (BatchDataIterator): validation dataset iterator
            epoch (int): the epoch number
            num_batches (int): the batch number
        Returns:
            stats (Statistics): epoch loss statistics
        """
        self.model.train()

        if self.stop:
            return
        header = '-' * 30 + "Epoch [%d]" + '-' * 30
        trace(header % current_epoch)
        train_stats = Statistics()
        num_batches = train_data.num_batches

        batch_cache = []
        for idx, batch in enumerate(iter(train_data), 1):
            batch_cache.append(batch)
            if len(batch_cache) == self.accum_grad_count or idx == num_batches:
                stats = self.train_each_batch(batch_cache, current_epoch, idx,
                                              num_batches)
                batch_cache = []
                if idx == train_data.num_batches:
                    train_stats.update(stats)
                if idx % self.report_every == 0 or idx == num_batches:
                    trace(
                        stats.report(current_epoch, idx, num_batches,
                                     self.optim.lr))
            if idx % (self.report_every * 10) == 0 and self.early_stop:
                valid_stats = self.validate(valid_data)
                trace("Validation: " + valid_stats.report(
                    current_epoch, idx, num_batches, self.optim.lr))
                if self.early_stop(valid_stats.ppl()):
                    self.stop = True
                    break
        valid_stats = self.validate(valid_data)
        trace(str(valid_stats))
        suffix = ".acc{0:.2f}.ppl{1:.2f}.e{2:d}".format(
            valid_stats.accuracy(), valid_stats.ppl(), current_epoch)
        self.optim.update_lr(valid_stats.ppl(), current_epoch)
        dump_checkpoint(self.model, self.save_model, suffix)
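The buffering in the loop above is gradient accumulation: batches pile up in batch_cache until accum_grad_count of them have arrived (or the epoch runs out of batches), and only then does train_each_batch process the whole buffer, emulating a larger effective batch size. Stripped of the pynmt specifics, the pattern is just:

accum_grad_count = 4                      # assumed value, for illustration
num_batches = 10                          # stand-in for train_data.num_batches
batch_cache = []
for idx in range(1, num_batches + 1):     # stand-in for enumerate(iter(train_data), 1)
    batch_cache.append(idx)
    if len(batch_cache) == accum_grad_count or idx == num_batches:
        print("one optimizer step over %d batches" % len(batch_cache))
        batch_cache = []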
Example #10
def make_embeddings(config, *vocab):
    """
    Make an Embeddings instance.
    Args:
        vocab (Vocab): words dictionary.
        config: global configuration settings.
    """

    if len(vocab) == 2:
        trace("Making independent embeddings ...")
        src_vocab, trg_vocab = vocab
        padding_idx = src_vocab.stoi[PAD_WORD]
        src_embeddings = nn.Embedding(src_vocab.vocab_size,
                                      config.src_embed_dim,
                                      padding_idx=padding_idx,
                                      max_norm=None,
                                      norm_type=2,
                                      scale_grad_by_freq=False,
                                      sparse=config.sparse_embeddings)
        trg_embeddings = nn.Embedding(trg_vocab.vocab_size,
                                      config.src_embed_dim,
                                      padding_idx=padding_idx,
                                      max_norm=None,
                                      norm_type=2,
                                      scale_grad_by_freq=False,
                                      sparse=config.sparse_embeddings)
    else:

        assert config.trg_embed_dim == config.src_embed_dim
        src_vocab = trg_vocab = vocab[0]
        padding_idx = trg_vocab.padding_idx
        src_embeddings = nn.Embedding(src_vocab.vocab_size,
                                      config.src_embed_dim,
                                      padding_idx=padding_idx,
                                      max_norm=None,
                                      norm_type=2,
                                      scale_grad_by_freq=False,
                                      sparse=config.sparse_embeddings)
        if config.share_embedding:
            trace("Making shared embeddings ...")
            trg_embeddings = src_embeddings
        else:
            trace("Making independent embeddings ...")
            trg_embeddings = nn.Embedding(trg_vocab.vocab_size,
                                          config.trg_embed_dim,
                                          padding_idx=padding_idx,
                                          max_norm=None,
                                          norm_type=2,
                                          scale_grad_by_freq=False,
                                          sparse=config.sparse_embeddings)
    return src_vocab, trg_vocab, src_embeddings, trg_embeddings
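When config.share_embedding is set, the target side reuses the same nn.Embedding object, so both languages index a single weight matrix. A tiny stand-alone check, with made-up sizes, shows the effect:

import torch.nn as nn

shared = nn.Embedding(32000, 512, padding_idx=1)   # assumed vocab size and dimension
src_embeddings = trg_embeddings = shared
print(src_embeddings.weight is trg_embeddings.weight)  # True: one parameter tensor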
Example #11
def report_bleu(reference_corpus, translation_corpus):

    bleu, precisions, bp, ratio, trans_length, ref_length = \
        compute_bleu([[x] for x in reference_corpus], translation_corpus)
    trace("BLEU: %.2f [%.2f/%.2f/%.2f/%.2f] Pred_len:%d, Ref_len:%d" %
          (bleu * 100, *precisions, trans_length, ref_length))
Example #12
File: Model.py Project: wang-h/pynmt
def param_init(self):
    trace("Initializing model parameters.")
    for p in self.parameters():
        p.data.uniform_(-0.1, 0.1)
Example #13
def check_config_exist(self):
    if not os.path.isfile(self.config_file):
        trace("# Cannot find the configuration file. "
              "{} does not exist! Please check.".format(self.config_file))
        sys.exit(1)
Example #14
File: DataLoader.py Project: wang-h/pynmt
def load_vocab(self, path):
    trace("Loading vocabulary ...")
    if self.share_vocab:
        self.trg_vocab = self.src_vocab = torch.load(path)
    else:
        self.src_vocab, self.trg_vocab = torch.load(path)
Example #15
File: DataLoader.py Project: wang-h/pynmt
def make_vocab(self):
    trace("Building vocabulary ...")
    self.src_vocab.make_vocab(map(lambda x: x[1][0], self.examples))
    self.trg_vocab.make_vocab(map(lambda x: x[1][1], self.examples))