Example #1
    def load_model(self):
        if not self.args.train_from:
            model = CopyRNN(self.args, self.vocab2id)
        else:
            model_path = self.args.train_from
            config_path = os.path.join(os.path.dirname(model_path),
                                       self.get_basename(model_path) + '.json')

            old_config = read_json(config_path)
            old_config['train_from'] = model_path
            # recover the global step encoded in the checkpoint filename
            # (the chunk between the last '_' and the file extension)
            old_config['step'] = int(model_path.rsplit('_', 1)[-1].split('.')[0])
            self.args = Munch(old_config)
            self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)

            model = CopyRNN(self.args, self.vocab2id)

            if torch.cuda.is_available():
                checkpoint = torch.load(model_path)
            else:
                checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
            state_dict = OrderedDict()
            # strip the 'module.' prefix so a checkpoint saved from an
            # nn.DataParallel-wrapped model loads into a plain model
            for k, v in checkpoint.items():
                if k.startswith('module.'):
                    k = k[7:]
                state_dict[k] = v
            model.load_state_dict(state_dict)

        return model
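
The OrderedDict loop above is the standard workaround for checkpoints saved from an nn.DataParallel-wrapped model, whose state-dict keys carry a 'module.' prefix. A minimal, self-contained sketch of the same idiom, with a toy model standing in for CopyRNN:

    from collections import OrderedDict

    import torch.nn as nn

    model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    # simulate a checkpoint saved from an nn.DataParallel wrapper:
    # every key gains a 'module.' prefix
    checkpoint = {'module.' + k: v for k, v in model.state_dict().items()}

    state_dict = OrderedDict()
    for k, v in checkpoint.items():
        if k.startswith('module.'):
            k = k[len('module.'):]  # drop the wrapper prefix
        state_dict[k] = v
    model.load_state_dict(state_dict)  # loads cleanly into the plain model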
Example #2
    def __init__(self, args, model):
        torch.manual_seed(0)
        # anomaly detection helps debug NaN/Inf gradients,
        # but slows training considerably
        torch.autograd.set_detect_anomaly(True)
        self.args = args
        self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)

        self.model = model
        if torch.cuda.is_available():
            self.model = self.model.cuda()
        if args.train_parallel:
            self.model = nn.DataParallel(self.model)
        # padding positions are excluded from the loss
        self.loss_func = nn.NLLLoss(ignore_index=self.vocab2id[PAD_WORD])
        self.optimizer = optim.Adam(self.model.parameters(),
                                    lr=self.args.learning_rate)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer,
                                                   self.args.schedule_step,
                                                   self.args.schedule_gamma)
        self.logger = get_logger('train')
        self.train_loader = KeyphraseDataLoader(
            data_source=self.args.train_filename,
            vocab2id=self.vocab2id,
            mode='train',
            args=args)
        if self.args.train_from:
            self.dest_dir = os.path.dirname(self.args.train_from) + '/'
        else:
            timemark = time.strftime('%Y%m%d-%H%M%S',
                                     time.localtime(time.time()))
            self.dest_dir = os.path.join(
                self.args.dest_base_dir,
                self.args.exp_name + '-' + timemark) + '/'
            os.mkdir(self.dest_dir)

        fh = logging.FileHandler(os.path.join(self.dest_dir, args.logfile))
        fh.setLevel(logging.INFO)
        fh.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
        self.logger.addHandler(fh)

        if not self.args.tensorboard_dir:
            tensorboard_dir = self.dest_dir + 'logs/'
        else:
            tensorboard_dir = self.args.tensorboard_dir
        self.writer = SummaryWriter(tensorboard_dir)
        self.eval_topn = (5, 10)
        self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro',
                                                  args.token_field,
                                                  args.keyphrase_field)
        self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro',
                                                  args.token_field,
                                                  args.keyphrase_field)
        self.best_f1 = None
        self.best_step = 0
        self.not_update_count = 0
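
The StepLR scheduler above multiplies the learning rate by schedule_gamma every schedule_step scheduler steps. A minimal sketch of the resulting decay, with illustrative hyperparameters standing in for the values read from args:

    import torch.nn as nn
    import torch.optim as optim

    model = nn.Linear(4, 2)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    # illustrative values: halve the lr every 100 scheduler steps
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)

    for step in range(300):
        optimizer.step()   # parameter update (loss/backward omitted)
        scheduler.step()   # advance the decay schedule
    print(optimizer.param_groups[0]['lr'])  # 1e-3 * 0.5 ** 3 = 1.25e-04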
Example #3
    def __init__(self, model_info, vocab_info, beam_size, max_target_len,
                 max_src_length):
        """

        :param model_info: input the model information.
                            str type: model path
                            dict type: must have `model` and `config` field,
                                        indicate the model object and config object

        :param vocab_info: input the vocab information.
                            str type: vocab path
                            dict type: vocab2id dict which map word to id
        :param beam_size: beam size
        :param max_target_len: max keyphrase token length
        :param max_src_length: max source text length
        """
        super().__init__(model_info)
        if isinstance(vocab_info, str):
            self.vocab2id = load_vocab(vocab_info)
        elif isinstance(vocab_info, dict):
            self.vocab2id = vocab_info
        else:
            raise ValueError('vocab_info must be a vocab path (str) '
                             'or a vocab2id dict')
        self.id2vocab = dict(zip(self.vocab2id.values(), self.vocab2id.keys()))
        self.config = self.load_config(model_info)
        self.model = self.load_model(model_info,
                                     CopyRNN(self.config, self.vocab2id))
        self.model.eval()
        self.beam_size = beam_size
        self.max_target_len = max_target_len
        self.max_src_len = max_src_length
        self.beam_searcher = BeamSearch(model=self.model,
                                        beam_size=self.beam_size,
                                        max_target_len=self.max_target_len,
                                        id2vocab=self.id2vocab,
                                        bos_idx=self.vocab2id[BOS_WORD],
                                        unk_idx=self.vocab2id[UNK_WORD],
                                        args=self.config)
        self.pred_base_config = {
            'max_oov_count': self.config.max_oov_count,
            'max_src_len': self.max_src_len,
            'max_target_len': self.max_target_len,
            'prefetch': False,
            'shuffle_in_batch': False,
            'token_field': TOKENS,
            'keyphrase_field': 'keyphrases'
        }
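
The id2vocab line above inverts the vocabulary by zipping values with keys. The dict comprehension below is equivalent whenever the mapping is one-to-one, which a vocabulary is by construction; the toy vocabulary is illustrative:

    vocab2id = {'<pad>': 0, '<bos>': 1, 'keyphrase': 2}  # toy vocabulary
    id2vocab = dict(zip(vocab2id.values(), vocab2id.keys()))
    assert id2vocab == {i: w for w, i in vocab2id.items()}
    assert id2vocab[2] == 'keyphrase'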
Example #4
    def __init__(self):
        self.args = self.parse_args()
        self.vocab2id = load_vocab(self.args.vocab_path)
        self.dest_base_dir = self.args.dest_base_dir
        self.writer = tf.summary.create_file_writer(self.dest_base_dir +
                                                    '/logs')
        self.exp_name = self.args.exp_name
        self.pad_idx = self.vocab2id[PAD_WORD]
        self.eval_topn = (5, 10)
        self.macro_evaluator = KeyphraseEvaluator(self.eval_topn, 'macro',
                                                  self.args.token_field,
                                                  self.args.keyphrase_field)
        self.micro_evaluator = KeyphraseEvaluator(self.eval_topn, 'micro',
                                                  self.args.token_field,
                                                  self.args.keyphrase_field)
        self.best_f1 = None
        self.best_step = 0
        self.not_update_count = 0
        self.logger = get_logger(__name__)
        # reserve extra ids for copy-mechanism OOV words
        self.total_vocab_size = len(self.vocab2id) + self.args.max_oov_count
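
tf.summary.create_file_writer is the TensorFlow 2 summary API; scalars written inside the writer's as_default() scope end up under dest_base_dir/logs, where TensorBoard picks them up. A minimal logging sketch, with an illustrative path and metric name:

    import tensorflow as tf

    writer = tf.summary.create_file_writer('/tmp/exp/logs')
    with writer.as_default():
        tf.summary.scalar('train/loss', 0.42, step=1)  # one scalar point
    writer.flush()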
Example #5
    def __init__(self, model_info, vocab_info, beam_size, max_target_len,
                 max_src_length):
        super().__init__(model_info)
        if isinstance(vocab_info, str):
            self.vocab2id = load_vocab(vocab_info)
        elif isinstance(vocab_info, dict):
            self.vocab2id = vocab_info
        else:
            raise ValueError('vocab_info must be a vocab path (str) '
                             'or a vocab2id dict')
        self.id2vocab = dict(zip(self.vocab2id.values(), self.vocab2id.keys()))
        self.config = self.load_config(model_info)
        self.model = self.load_model(model_info,
                                     CopyTransformer(self.config, self.vocab2id))
        self.model.eval()
        self.beam_size = beam_size
        self.max_target_len = max_target_len
        self.max_src_len = max_src_length
        self.beam_searcher = TransformerBeamSearch(model=self.model,
                                                   beam_size=self.beam_size,
                                                   max_target_len=self.max_target_len,
                                                   id2vocab=self.id2vocab,
                                                   bos_idx=self.vocab2id[BOS_WORD],
                                                   args=self.config)
        self.pred_base_config = {'max_oov_count': self.config.max_oov_count,
                                 'max_src_len': self.max_src_len,
                                 'max_target_len': self.max_target_len}
Example #6
    def __init__(self):
        self.args = self.parse_args()
        self.vocab2id = load_vocab(self.args.vocab_path, self.args.vocab_size)
        model = self.load_model()
        super().__init__(self.args, model)
Example #7
    def __init__(self):
        args = self.parse_args()
        vocab2id = load_vocab(args.vocab_path, vocab_size=args.vocab_size)
        model = CopyTransformer(args, vocab2id)
        super().__init__(args, model)
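
parse_args itself is not shown in any of these snippets. Below is a plausible argparse sketch covering the fields the constructors read; every flag name and default here is an assumption, not the project's actual CLI:

    import argparse

    def parse_args():
        # hypothetical flags inferred from the attributes used above
        parser = argparse.ArgumentParser()
        parser.add_argument('-vocab_path', type=str, required=True)
        parser.add_argument('-vocab_size', type=int, default=50000)
        parser.add_argument('-dest_base_dir', type=str, required=True)
        parser.add_argument('-exp_name', type=str, required=True)
        parser.add_argument('-max_oov_count', type=int, default=100)
        return parser.parse_args()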