コード例 #1
0
ファイル: metrics.py プロジェクト: zerogerc/rnn-autocomplete
class NonTerminalTerminalAccuracyMetrics(Metrics):
    """Accuracy metric split into non-terminal and terminal predictions.

    Holds two independent ``MaxPredictionAccuracyMetrics`` instances: one is
    fed the non-terminal (``nt``) prediction/target pair and the other the
    terminal (``t``) pair. Both values are reported together.
    """

    def __init__(self):
        super().__init__()
        self.nt_accuracy = MaxPredictionAccuracyMetrics()
        self.t_accuracy = MaxPredictionAccuracyMetrics()

    def drop_state(self):
        """Reset both wrapped accuracy metrics (non-terminal first)."""
        for sub_metric in (self.nt_accuracy, self.t_accuracy):
            sub_metric.drop_state()

    def report(self, data):
        """Feed one batch of results to both wrapped metrics.

        :param data: 4-tuple ``(nt_prediction, t_prediction, nt_target, t_target)``.
            Note the ordering: both predictions come first, then both targets.
            Each wrapped metric receives its own ``(prediction, target)`` pair.
        """
        nt_pred, t_pred, nt_tgt, t_tgt = data
        self.nt_accuracy.report((nt_pred, nt_tgt))
        self.t_accuracy.report((t_pred, t_tgt))

    def get_current_value(self, should_print=False):
        """Return ``(nt_value, t_value)`` and optionally print both.

        The wrapped metrics are always queried with ``should_print=False`` so
        that only the two labelled lines below are printed.

        :param should_print: when true, print the non-terminal accuracy line
            followed by the terminal accuracy line.
        :return: tuple of (non-terminal accuracy, terminal accuracy).
        """
        accuracies = (
            self.nt_accuracy.get_current_value(should_print=False),
            self.t_accuracy.get_current_value(should_print=False),
        )

        if should_print:
            templates = ('Non terminals accuracy: {}', 'Terminals accuracy: {}')
            for template, accuracy in zip(templates, accuracies):
                print(template.format(accuracy))

        return accuracies
コード例 #2
0
ファイル: results.py プロジェクト: zerogerc/rnn-autocomplete
def print_results(args, lim=1000):
    """Run the validation routine over the eval data and print running accuracy.

    Builds the experiment from ``args`` via ``get_main``, switches the model to
    eval mode, and feeds up to ``lim`` eval batches through
    ``main.validation_routine``. After each batch the accumulated
    ``MaxPredictionAccuracyMetrics`` value is printed.

    :param args: arguments forwarded unchanged to ``get_main``.
    :param lim: maximum number of eval batches to process. Passed to
        ``tqdm_lim``; the default of 1000 matches the previous hard-coded limit.
    :return: the last value returned by ``metrics.get_current_value``, or
        ``None`` if the eval generator produced no batches.
    """
    main = get_main(args)
    routine = main.validation_routine

    metrics = MaxPredictionAccuracyMetrics()
    metrics.drop_state()
    # Put the model into evaluation mode before running the validation routine.
    main.model.eval()

    current_value = None
    for iter_num, iter_data in enumerate(
            tqdm_lim(main.data_generator.get_eval_generator(), lim=lim)):
        metrics_data = routine.run(iter_num, iter_data)
        metrics.report(metrics_data)
        # Print the running accuracy every iteration (as before) and keep the
        # latest value so callers can use the final result.
        current_value = metrics.get_current_value(should_print=True)

    return current_value