class NonTerminalTerminalAccuracyMetrics(Metrics):
    """Tracks prediction accuracy separately for non-terminal and terminal tokens."""

    def __init__(self):
        super().__init__()
        self.nt_accuracy = MaxPredictionAccuracyMetrics()
        self.t_accuracy = MaxPredictionAccuracyMetrics()

    def drop_state(self):
        self.nt_accuracy.drop_state()
        self.t_accuracy.drop_state()

    def report(self, data):
        # data is a 4-tuple: non-terminal/terminal predictions and their targets.
        nt_prediction, t_prediction, nt_target, t_target = data
        self.nt_accuracy.report((nt_prediction, nt_target))
        self.t_accuracy.report((t_prediction, t_target))

    def get_current_value(self, should_print=False):
        nt_value = self.nt_accuracy.get_current_value(should_print=False)
        t_value = self.t_accuracy.get_current_value(should_print=False)
        if should_print:
            print('Non-terminals accuracy: {}'.format(nt_value))
            print('Terminals accuracy: {}'.format(t_value))
        return nt_value, t_value
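
# Minimal usage sketch (not part of the original module; shown as a comment so the
# file stays runnable). It assumes MaxPredictionAccuracyMetrics.report accepts
# (prediction_logits, target_indices) pairs of torch tensors, which the class above
# implies but does not show; the shapes and vocabulary sizes below are hypothetical.
#
#     import torch
#
#     metrics = NonTerminalTerminalAccuracyMetrics()
#     metrics.drop_state()
#     nt_prediction = torch.randn(8, 100)   # hypothetical batch of non-terminal logits
#     t_prediction = torch.randn(8, 1000)   # hypothetical batch of terminal logits
#     nt_target = torch.randint(0, 100, (8,))
#     t_target = torch.randint(0, 1000, (8,))
#     metrics.report((nt_prediction, t_prediction, nt_target, t_target))
#     nt_acc, t_acc = metrics.get_current_value(should_print=True)
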
def print_results(args):
    """Runs the validation routine over the eval set and prints the resulting accuracy."""
    # assert args.prediction == 'nt2n'
    # seed = 1000
    # random.seed(seed)
    # numpy.random.seed(seed)
    main = get_main(args)
    routine = main.validation_routine

    metrics = MaxPredictionAccuracyMetrics()
    metrics.drop_state()

    main.model.eval()
    for iter_num, iter_data in enumerate(
            tqdm_lim(main.data_generator.get_eval_generator(), lim=1000)):
        metrics_data = routine.run(iter_num, iter_data)
        metrics.report(metrics_data)

    metrics.get_current_value(should_print=True)
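
# Invocation sketch (hypothetical; `parse_args` and the exact fields expected on
# `args` are assumptions about this project's CLI and are not defined in this file):
#
#     if __name__ == '__main__':
#         args = parse_args()
#         print_results(args)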