示例#1
0
 def __init__(self,
              args=None,
              vocab=None,
              emb_matrix=None,
              model_file=None,
              use_cuda=False):
     """Create the trainer, either from a checkpoint or from scratch.

     Args:
         args: configuration dict. 'dict_only' is always read; when it is
             false, 'optim' and 'lr' are read as well. Ignored when
             ``model_file`` is given (the checkpoint's config is used).
         vocab: vocabulary object; its ``size`` attribute sizes the loss.
             Ignored when ``model_file`` is given.
         emb_matrix: optional embedding matrix forwarded to ``Seq2SeqModel``.
         model_file: path to a saved checkpoint. When not None, state is
             restored through ``self.load`` instead of being built here.
         use_cuda: move the model and loss to the GPU when true, else to CPU.
     """
     self.use_cuda = use_cuda
     if model_file is not None:
         # restore state from the checkpoint file
         self.load(model_file, use_cuda)
     else:
         self.args = args
         # Forward use_cuda the same way ``load`` does when it rebuilds the
         # model, so a fresh model and a reloaded one are constructed alike.
         # NOTE(review): presumably Seq2SeqModel uses the flag for
         # device-dependent tensors created at run time -- confirm.
         self.model = None if args['dict_only'] else Seq2SeqModel(
             args, emb_matrix=emb_matrix, use_cuda=use_cuda)
         self.vocab = vocab
         # Expansion lookup table; starts empty for a fresh trainer.
         self.expansion_dict = dict()
     if not self.args['dict_only']:
         self.crit = loss.SequenceLoss(self.vocab.size)
         # Only trainable parameters are handed to the optimizer.
         self.parameters = [
             p for p in self.model.parameters() if p.requires_grad
         ]
         if use_cuda:
             self.model.cuda()
             self.crit.cuda()
         else:
             self.model.cpu()
             self.crit.cpu()
         self.optimizer = utils.get_optimizer(self.args['optim'],
                                              self.parameters,
                                              self.args['lr'])
示例#2
0
 def __init__(self, args=None, vocab=None, emb_matrix=None, model_file=None, use_cuda=False):
     """Build a lemmatizer trainer, restoring it from ``model_file`` if given.

     Args:
         args: configuration dict. 'dict_only' is always read; when it is
             false, 'optim', 'lr', the optional 'edit' flag and (when 'edit'
             is truthy) 'alpha' are read too. Replaced by the checkpoint's
             config when ``model_file`` is given.
         vocab: vocabulary mapping; ``vocab['char'].size`` sizes the loss.
             Replaced by the checkpoint's vocab when ``model_file`` is given.
         emb_matrix: optional embedding matrix handed to ``Seq2SeqModel``.
         model_file: checkpoint path; when not None, ``self.load`` restores
             the state instead of building it from ``args``.
         use_cuda: put the model and loss on the GPU if true, else the CPU.
     """
     self.use_cuda = use_cuda
     if model_file is None:
         # Fresh trainer: configuration and vocab come from the caller.
         self.args = args
         if args['dict_only']:
             self.model = None
         else:
             self.model = Seq2SeqModel(args, emb_matrix=emb_matrix, use_cuda=use_cuda)
         self.vocab = vocab
         # Lookup tables for the dictionary-based components; start empty.
         self.word_dict = dict()
         self.composite_dict = dict()
     else:
         # Checkpoint: let load() restore the saved state.
         self.load(model_file, use_cuda)

     if self.args['dict_only']:
         # Dictionary-only mode builds no loss, parameter list or optimizer.
         return

     use_edit = self.args.get('edit', False)
     if use_edit:
         self.crit = loss.MixLoss(self.vocab['char'].size, self.args['alpha'])
         logger.debug("Running seq2seq lemmatizer with edit classifier...")
     else:
         self.crit = loss.SequenceLoss(self.vocab['char'].size)

     # Only trainable parameters are handed to the optimizer.
     self.parameters = [
         param for param in self.model.parameters()
         if param.requires_grad
     ]

     # Model first, then loss -- both go to the same device.
     for module in (self.model, self.crit):
         if use_cuda:
             module.cuda()
         else:
             module.cpu()

     self.optimizer = utils.get_optimizer(
         self.args['optim'], self.parameters, self.args['lr'])
示例#3
0
 def load(self, filename, use_cuda=False):
     """Restore trainer state from a checkpoint file.

     Populates ``self.args`` (from 'config'), ``self.expansion_dict`` (from
     'dict'), ``self.model`` (rebuilt and loaded from 'model', or None when
     ``args['dict_only']`` is true) and ``self.vocab`` (from 'vocab').

     Args:
         filename: path of the checkpoint to read with ``torch.load``.
         use_cuda: forwarded to ``Seq2SeqModel`` when a model is rebuilt.

     Raises:
         Exception: any error from ``torch.load`` is logged and re-raised so
             the caller decides how to handle it; the process is not exited.
     """
     try:
         # A map_location that returns the storage unchanged loads every
         # tensor onto the CPU, whatever device it was saved from.
         checkpoint = torch.load(filename, lambda storage, loc: storage)
     except Exception:
         # Log and re-raise rather than calling sys.exit(): a library method
         # should not terminate the interpreter. Catching Exception (not
         # BaseException) lets KeyboardInterrupt/SystemExit pass untouched.
         logger.error("Cannot load model from {}".format(filename))
         raise
     self.args = checkpoint['config']
     self.expansion_dict = checkpoint['dict']
     if not self.args['dict_only']:
         self.model = Seq2SeqModel(self.args, use_cuda=use_cuda)
         self.model.load_state_dict(checkpoint['model'])
     else:
         self.model = None
     self.vocab = Vocab.load_state_dict(checkpoint['vocab'])
示例#4
0
 def load(self, filename, use_cuda=False):
     """Restore lemmatizer trainer state from a checkpoint file.

     Populates ``self.args`` (from 'config'), ``self.word_dict`` and
     ``self.composite_dict`` (unpacked from 'dicts'), ``self.model`` (rebuilt
     and loaded from 'model', or None when ``args['dict_only']`` is true)
     and ``self.vocab`` (from 'vocab').

     Args:
         filename: path of the checkpoint to read with ``torch.load``.
         use_cuda: forwarded to ``Seq2SeqModel`` when a model is rebuilt.

     Raises:
         Exception: any error from ``torch.load`` is logged and re-raised.
     """
     try:
         # A map_location that returns the storage unchanged loads every
         # tensor onto the CPU, whatever device it was saved from.
         checkpoint = torch.load(filename, lambda storage, loc: storage)
     except Exception:
         # Catch Exception rather than BaseException so KeyboardInterrupt /
         # SystemExit propagate without a misleading "cannot load" message.
         logger.error("Cannot load model from {}".format(filename))
         raise
     self.args = checkpoint['config']
     self.word_dict, self.composite_dict = checkpoint['dicts']
     if not self.args['dict_only']:
         self.model = Seq2SeqModel(self.args, use_cuda=use_cuda)
         self.model.load_state_dict(checkpoint['model'])
     else:
         self.model = None
     self.vocab = MultiVocab.load_state_dict(checkpoint['vocab'])