예제 #1
0
파일: main_dual.py 프로젝트: Ocxs/ADAPT
def main(_):
    """Run the DUAL experiment in the mode configured in its params.

    Builds the DualParams configuration, the DUAL model and its input
    pipeline, then hands them to the estimator_solver entry point matching
    params['mode'] ('train', 'eval' or 'infer'). Afterwards
    utils.stat_eval_results is called on 'eval_result.txt' in the model dir.

    Raises:
        ValueError: if params['mode'] is not one of 'train', 'eval', 'infer'.
    """
    project_name = 'DUAL'
    dual_params = DualParams(project_name)
    dual_params.post_process()
    params = dual_params.params
    run_mode = params['mode']

    # Both objects are built before the mode is validated, as before.
    model_fn = DUALModel(params)
    input_fn = DUALBaseData(params)

    # mode -> (banner printed before running, estimator_solver function name)
    dispatch = {
        'train': ('train&eval', 'train'),
        'eval': ('eval', 'evaluate'),
        'infer': ('predict', 'predict'),
    }
    if run_mode not in dispatch:
        raise ValueError('invalid mode: {}'.format(run_mode))

    banner, solver_name = dispatch[run_mode]
    print(banner)
    # Resolve the solver function lazily so only the selected one is looked up.
    run_solver = getattr(estimator_solver, solver_name)
    run_solver(input_fn=input_fn, model_fn=model_fn, params=params)

    utils.stat_eval_results(params['model_dir'], 'eval_result.txt')
예제 #2
0
파일: main_aden_ft.py 프로젝트: Ocxs/ADAPT
def main(_):
    """Run the ADEN_ft experiment in the mode configured in its params.

    Builds the MixedSeqMultiModalFtParams configuration, model and input
    pipeline, dispatches on params['mode'] to estimator_solver.train /
    evaluate / predict, then runs a final estimator_solver.evaluate and calls
    utils.stat_eval_results on 'eval_result.txt' in the model dir.

    Raises:
        ValueError: if params['mode'] is not one of 'train', 'eval', 'infer'.
    """
    project_name = 'ADEN_ft'
    seq_params = MixedSeqMultiModalFtParams(project_name)
    seq_params.post_process()
    mode = seq_params.params['mode']

    model_fn = MixedSeqMultiModalFtModel(seq_params.params)
    input_fn = MixedSeqMultiModalData(seq_params.params)

    if mode == 'train':
        print('train&eval')
        estimator_solver.train(input_fn=input_fn,
                               model_fn=model_fn,
                               params=seq_params.params)
    elif mode == 'eval':
        print('eval')
        estimator_solver.evaluate(input_fn=input_fn,
                                  model_fn=model_fn,
                                  params=seq_params.params)
    elif mode == 'infer':
        print('predict')
        estimator_solver.predict(input_fn=input_fn,
                                 model_fn=model_fn,
                                 params=seq_params.params)
    else:
        raise ValueError('invalid mode: {}'.format(mode))

    # In 'eval' mode the identical evaluation has just run above; running it
    # again would only repeat the same work. Other modes still get a final
    # evaluation pass as before.
    if mode != 'eval':
        print('------------ evaluate ------------')
        tf.logging.info('------------ evaluate ------------')
        estimator_solver.evaluate(input_fn=input_fn,
                                  model_fn=model_fn,
                                  params=seq_params.params)
    model_dir = seq_params.params['model_dir']
    utils.stat_eval_results(model_dir, 'eval_result.txt')
예제 #3
0
 def build_eval_estimator_spec(self, mode):
     """Build the evaluation EstimatorSpec.

     Args:
         mode: estimator mode, passed straight through to EstimatorSpec.

     Returns:
         The tf.estimator.EstimatorSpec carrying self.loss plus the
         'metric/test/content_accuracy' (self.accuracy) and
         'metric/test/content_roc_auc' (self.roc_auc) eval metric ops.
         The spec is also stored on self.eval_est_spec for existing callers.
     """
     eval_est_spec = tf.estimator.EstimatorSpec(
         mode,
         loss=self.loss,
         eval_metric_ops={
             'metric/test/content_accuracy': self.accuracy,
             'metric/test/content_roc_auc': self.roc_auc,
         })
     # Runs when the spec is built (graph construction time), not after the
     # evaluation finishes; presumably summarises earlier temp results.
     utils.stat_eval_results(self.model_dir, 'eval_result_temp.txt')
     self.eval_est_spec = eval_est_spec
     # Also return the spec so this method can be used like other
     # build_*_estimator_spec implementations that return it directly.
     return eval_est_spec
예제 #4
0
파일: DUAL.py 프로젝트: Ocxs/ADAPT
 def build_eval_estimator_spec(self, mode):
     """Return an evaluation EstimatorSpec that carries only the loss.

     Before the spec is constructed, utils.stat_eval_results is invoked on
     'eval_result_temp.txt' inside self.model_dir.

     Args:
         mode: estimator mode, passed straight through to EstimatorSpec.
     """
     utils.stat_eval_results(self.model_dir, 'eval_result_temp.txt')
     spec = tf.estimator.EstimatorSpec(mode, loss=self.loss)
     return spec