Example #1
0
 def __init__(self):
     """Initialize the step counter, best-F1 tracker, and weight EMA.

     Relies on module-level globals `model` and `config`, and the helper
     `ExponentialMovingAverage` — defined elsewhere in the project.
     """
     # Zero-arg super() correctly invokes the parent initializer; the
     # original `super(Callback, self).__init__()` named the PARENT class
     # as the search start, so MRO lookup skipped Callback.__init__
     # entirely (it resolved to object.__init__).
     super().__init__()
     self.global_step = 1  # incremented once per batch
     self.max_f1 = 0       # best dev F1 observed so far
     # EMA of model weights; shadow weights are staged via a temp .h5 file.
     self.keras_ema = ExponentialMovingAverage(
         model,
         decay=config.ema_decay,
         temp_model=os.path.join(config.path, 'temp_model.h5'),
         type='cpu')
Example #2
0
class QANet_callback(Callback):
    """Keras training callback for QANet.

    Responsibilities:
      * logarithmic learning-rate warm-up, updated every batch;
      * exponential moving average (EMA) of model weights;
      * per-epoch SQuAD-style dev evaluation with EMA weights, metric
        logging to CSV, and best-F1 checkpointing.

    NOTE(review): depends on module-level globals — `model`, `config`,
    `dev_data`, `eval_examples`, `eval_features`, `ems`, `f1s` — and
    helpers `ExponentialMovingAverage`, `RawResult`, `write_predictions`,
    `evaluate`, plus `np`, `K`, `pd`, `os`; confirm against the rest of
    the file.
    """

    def __init__(self):
        # Zero-arg super() correctly runs Callback.__init__; the original
        # `super(Callback, self).__init__()` started the MRO search AFTER
        # Callback, so its initializer was silently skipped.
        super().__init__()
        self.global_step = 1  # incremented once per batch
        self.max_f1 = 0       # best dev F1 observed so far
        # EMA of model weights; shadow weights are swapped in at epoch end.
        self.keras_ema = ExponentialMovingAverage(
            model,
            decay=config.ema_decay,
            temp_model=os.path.join(config.path, 'temp_model.h5'),
            type='cpu')

    def _warmup_lr(self):
        """Return the warm-up learning rate for the current global step.

        Logarithmic ramp: 0 at step 1, reaching the full
        `config.learning_rate` at `config.warm_up_steps`, constant after.
        """
        return min(
            config.learning_rate, config.learning_rate /
            np.log(config.warm_up_steps) * np.log(self.global_step))

    def on_train_begin(self, logs=None):
        """Set the optimizer learning rate before the first batch."""
        K.set_value(self.model.optimizer.lr, self._warmup_lr())

    def on_batch_end(self, batch, logs=None):
        """Advance the step counter, refresh the LR, and update the EMA."""
        self.global_step += 1
        K.set_value(self.model.optimizer.lr, self._warmup_lr())
        self.keras_ema.average_update()

    def on_epoch_end(self, epoch, logs=None):
        """Evaluate on the dev set with EMA weights and checkpoint on best F1.

        Side effects: appends to the global `ems`/`f1s` lists, writes
        prediction/nbest JSON files and a result CSV, saves best weights,
        then restores the pre-EMA training weights.
        """
        # Swap the EMA shadow weights into the model for evaluation.
        self.keras_ema.assign_shadow_weights()
        logits1, logits2, _, _ = self.model.predict(
            x=[
                dev_data['context_id'], dev_data['question_id'],
                dev_data['context_char_id'], dev_data['question_char_id']
            ],
            batch_size=config.batch_size,
            verbose=1)
        all_results = [
            RawResult(qid=qid,
                      start_logits=logits1[i, :],
                      end_logits=logits2[i, :])
            for i, qid in enumerate(dev_data['qid'])
        ]
        output_prediction_file = os.path.join(config.path,
                                              'output_prediction.json')
        output_nbest_file = os.path.join(config.path, 'output_nbest.json')
        write_predictions(eval_examples,
                          eval_features,
                          all_results,
                          n_best_size=20,
                          max_answer_length=config.ans_limit,
                          do_lower_case=False,
                          output_prediction_file=output_prediction_file,
                          output_nbest_file=output_nbest_file)
        metrics = evaluate('original_data/dev-v1.1.json',
                           output_prediction_file, None)
        ems.append(metrics['exact'])
        f1s.append(metrics['f1'])
        result = pd.DataFrame([ems, f1s], index=['em', 'f1']).transpose()
        # index=False (explicit) — the original passed index=None, which
        # pandas treats as falsy; same behavior, clearer intent.
        result.to_csv('logs/result_' + config.name + '.csv', index=False)
        if f1s[-1] > self.max_f1:
            self.max_f1 = f1s[-1]
            model.save_weights(
                os.path.join(config.path,
                             'QANet_model_' + config.name + '.h5'))
        # Restore the raw (non-EMA) training weights so the next epoch
        # continues from where training left off.
        model.load_weights(self.keras_ema.temp_model)