def __init__(self, terminal: Terminal, model: NPIStep = None, recording=True, max_depth=100, max_step=10000,
             result_logger=None):
    """Set up the runner's collaborators, limits and per-run bookkeeping.

    :param terminal: terminal object stored on the instance as ``self.terminal``.
    :param model: NPI step model stored as ``self.model``; may be ``None``.
    :param recording: flag stored as ``self.recording``.
    :param max_depth: stored as ``self.max_depth`` (presumably a call-depth
        limit enforced elsewhere in the class -- confirm against its users).
    :param max_step: stored as ``self.max_step`` (presumably a step-count
        limit enforced elsewhere in the class -- confirm against its users).
    :param result_logger: logger stored as ``self.result_logger``. When
        omitted, a ``ResultLogger('result_multiplication.log')`` is created
        for this instance.
    """
    # The default logger used to be built in the signature, i.e. once at
    # function-definition time: importing the module constructed it as a side
    # effect and every runner relying on the default shared that one object.
    # Building it here keeps the same default target without either effect.
    if result_logger is None:
        result_logger = ResultLogger('result_multiplication.log')

    # Collaborators and configuration supplied by the caller.
    self.terminal = terminal
    self.model = model
    self.recording = recording
    self.max_depth = max_depth
    self.max_step = max_step
    self.result_logger = result_logger

    # Per-run bookkeeping and fixed settings.
    self.steps = 0
    self.step_list = []
    self.alpha = 0.5
    self.verbose = True
# NOTE(review): this physical line is the tail of a function body whose `def`
# header is not visible here, fused with the script entry block. The original
# nesting cannot be recovered from the text, so the code is left untouched.
#
# Function-body part (reads num, program_set, terminal, addition_env,
# result_logger and filename, which must come from the missing header or an
# enclosing scope):
#   * builds `questions` with create_questions(num) and a TerminalNPIRunner
#     driven by an AdditionTeacher; the runner's `verbose` flag follows the
#     module-level DEBUG_MODE.
#   * for each question (presumably all inside the loop -- confirm): resets
#     addition_env, takes a shallow copy `q` of the entry before run_npi is
#     called with it, appends {"q": q, "steps": npi_runner.step_list} to
#     steps_list, then passes `data` to result_logger.write and
#     terminal.add_log.
#   * NOTE(review): the appended dict holds a reference to
#     npi_runner.step_list; whether the runner rebinds a fresh list per run is
#     not visible here -- verify, otherwise all entries alias one list.
#   * when `filename` is truthy, pickles steps_list to it with
#     pickle.HIGHEST_PROTOCOL.
#
# Entry-block part (`if __name__ == '__main__':`):
#   * DEBUG_MODE is the raw value of the DEBUG environment variable, so any
#     non-empty string enables it.
#   * debug mode: no output file, 3 samples, log file 'result.log'.
#   * otherwise: argv[1] = output file (default None), argv[2] = sample count
#     (default 1000), argv[3] = log file (default 'result.log').
#   * curses.wrapper invokes main with the curses window first, followed by the
#     three arguments given here; afterwards the sample count is printed.
questions = create_questions(num) teacher = AdditionTeacher(program_set) npi_runner = TerminalNPIRunner(terminal, teacher) npi_runner.verbose = DEBUG_MODE steps_list = [] for data in questions: addition_env.reset() q = copy(data) run_npi(addition_env, npi_runner, program_set.ADD, data) steps_list.append({"q": q, "steps": npi_runner.step_list}) result_logger.write(data) terminal.add_log(data) if filename: with open(filename, 'wb') as f: pickle.dump(steps_list, f, protocol=pickle.HIGHEST_PROTOCOL) if __name__ == '__main__': import sys DEBUG_MODE = os.environ.get('DEBUG') if DEBUG_MODE: output_filename = None num_data = 3 log_filename = 'result.log' else: output_filename = sys.argv[1] if len(sys.argv) > 1 else None num_data = int(sys.argv[2]) if len(sys.argv) > 2 else 1000 log_filename = sys.argv[3] if len(sys.argv) > 3 else 'result.log' curses.wrapper(main, output_filename, num_data, ResultLogger(log_filename)) print("create %d training data" % num_data)
# NOTE(review): this physical line is the tail of a function body whose `def`
# header is not visible here, fused with the script entry block. The original
# nesting cannot be recovered from the text, so the code is left untouched.
#
# Function-body part (reads terminal, model_path, program_set, questions,
# Bubblesort_env and result_logger, which must come from the missing header or
# an enclosing scope):
#   * builds a RuntimeSystem on the terminal, a BubblesortNPIModel from
#     model_path, and a TerminalNPIRunner with recording=False; the runner's
#     `verbose` flag follows the module-level DEBUG_MODE.
#   * for each question: resets Bubblesort_env and calls run_npi for
#     program_set.BUBBLESORT. A truthy data['correct'] afterwards counts as
#     correct, anything else as wrong; a StopIteration raised during the run is
#     also counted as wrong. `data` is then passed to result_logger.write and
#     terminal.add_log (presumably for every question, including failed ones
#     -- confirm against the original indentation).
#   * NOTE(review): the `pass` after `wrong_count += 1` in the except branch is
#     redundant.
#   * returns the tuple (correct_count, wrong_count).
#
# Entry-block part (`if __name__ == '__main__':`):
#   * DEBUG_MODE is the raw value of the DEBUG environment variable, so any
#     non-empty string enables it.
#   * argv[1] = model path (required), argv[2] = sample count (default 1000),
#     argv[3] = log file (default 'result.log').
#   * NOTE(review): sys.argv[1] is read without a length check, so running the
#     script with no arguments raises IndexError.
#   * curses.wrapper invokes main with the curses window first, followed by the
#     three arguments given here, and hands back main's return value.
#   * NOTE(review): cc / (cc + wc) raises ZeroDivisionError when both counts are
#     zero, e.g. if no questions were evaluated.
system = RuntimeSystem(terminal=terminal) npi_model = BubblesortNPIModel(system, model_path, program_set) npi_runner = TerminalNPIRunner(terminal, npi_model, recording=False) npi_runner.verbose = DEBUG_MODE correct_count = wrong_count = 0 for data in questions: Bubblesort_env.reset() try: run_npi(Bubblesort_env, npi_runner, program_set.BUBBLESORT, data) if data['correct']: correct_count += 1 else: wrong_count += 1 except StopIteration: wrong_count += 1 pass result_logger.write(data) terminal.add_log(data) return correct_count, wrong_count if __name__ == '__main__': import sys DEBUG_MODE = os.environ.get('DEBUG') model_path_ = sys.argv[1] num_data = int(sys.argv[2]) if len(sys.argv) > 2 else 1000 log_filename = sys.argv[3] if len(sys.argv) > 3 else 'result.log' cc, wc = curses.wrapper(main, model_path_, num_data, ResultLogger(log_filename)) print("Accuracy %s(OK=%d, NG=%d)" % (cc / (cc + wc), cc, wc))