Example #1
    def __init__(self,
                 args=None,
                 vocab=None,
                 pretrain=None,
                 model_file=None,
                 use_cuda=False,
                 train_classifier_only=False):
        self.use_cuda = use_cuda
        if model_file is not None:
            # load everything from file
            self.load(model_file, args)
        else:
            assert all(var is not None for var in [args, vocab, pretrain])
            # build model from scratch
            self.args = args
            self.vocab = vocab
            self.model = NERTagger(args, vocab, emb_matrix=pretrain.emb)

        if train_classifier_only:
            logger.info('Disabling gradient for non-classifier layers')
            exclude = ['tag_clf', 'crit']
            for pname, p in self.model.named_parameters():
                if pname.split('.')[0] not in exclude:
                    p.requires_grad = False
        self.parameters = [
            p for p in self.model.parameters() if p.requires_grad
        ]
        if self.use_cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        self.optimizer = utils.get_optimizer(self.args['optim'],
                                             self.parameters,
                                             self.args['lr'],
                                             momentum=self.args['momentum'])
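The `train_classifier_only` branch above freezes every parameter whose top-level submodule name is not in `exclude`, so only the classifier head (and the CRF criterion) remains trainable. Below is a minimal, self-contained sketch of that freezing pattern; the module names are hypothetical stand-ins, not the actual NERTagger layers.

import torch.nn as nn

# Hypothetical stand-in model; only 'tag_clf' should stay trainable.
model = nn.ModuleDict({
    'word_emb': nn.Embedding(100, 16),
    'taggerlstm': nn.LSTM(16, 16),
    'tag_clf': nn.Linear(16, 8),
})
exclude = ['tag_clf']
for pname, p in model.named_parameters():
    # named_parameters() yields names like 'taggerlstm.weight_ih_l0';
    # the prefix before the first '.' identifies the submodule.
    if pname.split('.')[0] not in exclude:
        p.requires_grad = False

trainable = [p for p in model.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # parameter count of the classifier head only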
Example #2
    def __init__(self,
                 args=None,
                 vocab=None,
                 pretrain=None,
                 model_file=None,
                 use_cuda=False):
        self.use_cuda = use_cuda
        if model_file is not None:
            # load everything from file
            self.load(model_file, args)
        else:
            assert all(var is not None for var in [args, vocab, pretrain])
            # build model from scratch
            self.args = args
            self.vocab = vocab
            self.model = NERTagger(args, vocab, emb_matrix=pretrain.emb)
        self.parameters = [
            p for p in self.model.parameters() if p.requires_grad
        ]
        if self.use_cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        self.optimizer = utils.get_optimizer(self.args['optim'],
                                             self.parameters,
                                             self.args['lr'],
                                             momentum=self.args['momentum'])
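`utils.get_optimizer` itself is not shown in these snippets; judging from the call site, it builds an optimizer from a name string, the trainable parameters, a learning rate, and an optional momentum. A hypothetical sketch of such a factory (not the library's actual implementation) could look like this:

import torch

def get_optimizer(name, parameters, lr, momentum=0):
    # Hypothetical factory matching the call
    # utils.get_optimizer(args['optim'], parameters, args['lr'], momentum=args['momentum']).
    if name == 'sgd':
        return torch.optim.SGD(parameters, lr=lr, momentum=momentum)
    if name == 'adam':
        return torch.optim.Adam(parameters, lr=lr)
    if name == 'adamax':
        return torch.optim.Adamax(parameters, lr=lr)
    raise ValueError("Unsupported optimizer: {}".format(name))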
Example #3
    def load(self, filename, args=None):
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args: self.args.update(args)
        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
        self.model = NERTagger(self.args, self.vocab)
        self.model.load_state_dict(checkpoint['model'], strict=False)
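This `load()` mirrors the `save()` shown in Example #4: the checkpoint is a plain dict with 'model' (a state_dict), 'vocab' (a serialized vocabulary), and 'config' (the args dict). A minimal runnable sketch of that round trip, with stand-ins for the real model, vocab, and file name:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for NERTagger
params = {
    'model': model.state_dict(),
    'vocab': {'tag': ['O', 'B-PER', 'I-PER']},                # stand-in for MultiVocab.state_dict()
    'config': {'optim': 'sgd', 'lr': 0.1, 'momentum': 0.9},   # stand-in for args
}
torch.save(params, 'ner_checkpoint.pt')

# Map all storages to CPU at load time, as the lambda in load() does.
checkpoint = torch.load('ner_checkpoint.pt', lambda storage, loc: storage)
model.load_state_dict(checkpoint['model'], strict=False)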
Example #4
class Trainer(BaseTrainer):
    """ A trainer for training models. """
    def __init__(self,
                 args=None,
                 vocab=None,
                 pretrain=None,
                 model_file=None,
                 use_cuda=False,
                 train_classifier_only=False):
        self.use_cuda = use_cuda
        if model_file is not None:
            # load everything from file
            self.load(model_file, args)
        else:
            assert all(var is not None for var in [args, vocab, pretrain])
            # build model from scratch
            self.args = args
            self.vocab = vocab
            self.model = NERTagger(args, vocab, emb_matrix=pretrain.emb)

        if train_classifier_only:
            logger.info('Disabling gradient for non-classifier layers')
            exclude = ['tag_clf', 'crit']
            for pname, p in self.model.named_parameters():
                if pname.split('.')[0] not in exclude:
                    p.requires_grad = False
        self.parameters = [
            p for p in self.model.parameters() if p.requires_grad
        ]
        if self.use_cuda:
            self.model.cuda()
        else:
            self.model.cpu()
        self.optimizer = utils.get_optimizer(self.args['optim'],
                                             self.parameters,
                                             self.args['lr'],
                                             momentum=self.args['momentum'])

    def update(self, batch, eval=False):
        inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(
            batch, self.use_cuda)
        word, word_mask, wordchars, wordchars_mask, chars, tags = inputs

        if eval:
            self.model.eval()
        else:
            self.model.train()
            self.optimizer.zero_grad()
        loss, _, _ = self.model(word, word_mask, wordchars, wordchars_mask,
                                tags, word_orig_idx, sentlens, wordlens, chars,
                                charoffsets, charlens, char_orig_idx)
        loss_val = loss.data.item()
        if eval:
            return loss_val

        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                       self.args['max_grad_norm'])
        self.optimizer.step()
        return loss_val

    def predict(self, batch, unsort=True):
        inputs, orig_idx, word_orig_idx, char_orig_idx, sentlens, wordlens, charlens, charoffsets = unpack_batch(
            batch, self.use_cuda)
        word, word_mask, wordchars, wordchars_mask, chars, tags = inputs

        self.model.eval()
        batch_size = word.size(0)
        _, logits, trans = self.model(word, word_mask, wordchars,
                                      wordchars_mask, tags, word_orig_idx,
                                      sentlens, wordlens, chars, charoffsets,
                                      charlens, char_orig_idx)

        # decode
        trans = trans.data.cpu().numpy()
        scores = logits.data.cpu().numpy()
        bs = logits.size(0)
        tag_seqs = []
        for i in range(bs):
            tags, _ = viterbi_decode(scores[i, :sentlens[i]], trans)
            tags = self.vocab['tag'].unmap(tags)
            tag_seqs += [tags]

        if unsort:
            tag_seqs = utils.unsort(tag_seqs, orig_idx)
        return tag_seqs

    def save(self, filename, skip_modules=True):
        model_state = self.model.state_dict()
        # skip saving modules like pretrained embeddings, because they are large and will be saved in a separate file
        if skip_modules:
            skipped = [
                k for k in model_state.keys()
                if k.split('.')[0] in self.model.unsaved_modules
            ]
            for k in skipped:
                del model_state[k]
        params = {
            'model': model_state,
            'vocab': self.vocab.state_dict(),
            'config': self.args
        }
        try:
            torch.save(params, filename, _use_new_zipfile_serialization=False)
            logger.info("Model saved to {}".format(filename))
        except (KeyboardInterrupt, SystemExit):
            raise
        except:
            logger.warning("Saving failed... continuing anyway.")

    def load(self, filename, args=None):
        try:
            checkpoint = torch.load(filename, lambda storage, loc: storage)
        except BaseException:
            logger.error("Cannot load model from {}".format(filename))
            raise
        self.args = checkpoint['config']
        if args: self.args.update(args)
        self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])
        self.model = NERTagger(self.args, self.vocab)
        self.model.load_state_dict(checkpoint['model'], strict=False)
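For orientation: these snippets rely on names (torch, logger, utils, NERTagger, MultiVocab, viterbi_decode, unpack_batch, BaseTrainer) that are expected to be imported or defined in the surrounding module; the code appears to come from Stanza's NER trainer. The sketch below shows how such a Trainer is typically driven; args, vocab, pretrain, use_cuda, model_file, and the batch iterators are assumed to be built by the surrounding training script and are not defined here.

trainer = Trainer(args=args, vocab=vocab, pretrain=pretrain, use_cuda=use_cuda)

for epoch in range(num_epochs):                              # hypothetical loop bounds
    for batch in train_batches:
        train_loss = trainer.update(batch, eval=False)       # forward, backward, optimizer step
    dev_loss = sum(trainer.update(b, eval=True) for b in dev_batches)
    trainer.save(model_file)

trainer = Trainer(model_file=model_file, use_cuda=use_cuda)  # reload for inference
tag_seqs = trainer.predict(dev_batches[0])                   # one decoded tag sequence per sentence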