def evaluate(
    self,
    logger: EvaluateLogger,
    filter_splits: Optional[Tuple[DataSplits]] = (DataSplits.VALIDATION,),
    dump_each: bool = False,
    num_replace_samples=1
):
    """Run the model over the data pairs and log one evaluation per pair.

    The model is switched to eval mode for the duration of the call and is
    always switched back to train mode afterwards, even when evaluation
    raises.

    Args:
        logger: Receives every result through ``add_evaluation``.
        filter_splits: Forwarded to ``self.data_pair_iterate`` to select
            which data to iterate.
        dump_each: When true, print every freshly computed evaluation,
            preceded by a ``---`` line whenever the example id differs from
            the previously printed one.
        num_replace_samples: Number of complete passes made over the data
            pairs.
    """
    self.model.set_in_eval_mode()
    try:
        # Kinda hacky approximation of sampling replacements multiple times:
        # just iterate over everything multiple times.
        dups = flatten_list([
            list(self.data_pair_iterate(filter_splits))
            for _ in range(num_replace_samples)
        ])
        # Sorting by example id puts the repeated passes of one example next
        # to each other, so an identical consecutive query can reuse the
        # previous result instead of calling the model again.
        dups.sort(key=lambda d: d[0].id)
        last_x = None
        last_eval_result = None
        last_example_id = None
        for data in dups:
            (example, replaced_x_query, y_ast_set, _this_example_ast,
             y_texts, _rsample) = data
            if last_x == replaced_x_query:
                logger.add_evaluation(last_eval_result)
                continue
            parse_exception = None
            try:
                prediction, _ = self.model.predict(
                    replaced_x_query,
                    example.get_y_type(self.example_store),
                    True
                )
            except (ModelCantPredictException, ModelSafePredictError) as e:
                # A failed prediction is still logged, with the exception
                # attached in place of a predicted AST.
                prediction = None
                parse_exception = e
            evaluation = AstEvaluation(
                prediction, y_ast_set, y_texts, replaced_x_query,
                parse_exception, self.unparser
            )
            last_x = replaced_x_query
            last_eval_result = evaluation
            logger.add_evaluation(evaluation)
            if dump_each:
                if example.id != last_example_id:
                    print("---")
                evaluation.print_vals(self.unparser)
                last_example_id = example.id
    finally:
        # Never leave the model stuck in eval mode for the caller.
        self.model.set_in_train_mode()
def assert_acc(model, example_store, splits, required_accuracy, expect_fail):
    """Assert that a model reaches (or fails to reach) an exact-match accuracy.

    A ``TypeTranslateCFTrainer`` is built around ``model`` and
    ``example_store`` and evaluated on ``splits``. With a falsy
    ``expect_fail`` the 'ExactMatch' true fraction must be at least
    ``required_accuracy``; otherwise it must not be.
    """
    trainer = TypeTranslateCFTrainer(model, example_store)
    eval_log = EvaluateLogger()
    trainer.evaluate(eval_log, splits)
    reached = eval_log.stats['ExactMatch'].true_frac >= required_accuracy
    if expect_fail:
        assert not reached
    else:
        assert reached
def eval_stuff(
    xs: List[str],
    ys: List[str],
    preds: List[str],
    y_ast_sets: Iterable[Tuple[AstObjectChoiceSet, Set[str]]],
    string_parser: StringParser,
    unparser: AstUnparser
):
    """Parse each predicted string, evaluate it, and print a summary log.

    Every prediction is parsed with ``string_parser`` using the type name
    taken from its expected AST set. A prediction that fails to parse (or
    overflows the recursion limit while parsing) is still evaluated, with
    the exception recorded and no predicted AST.

    ``ys`` is accepted but not read by this function.
    """
    logger = EvaluateLogger()
    assert len(xs) == len(preds)
    for query, predicted_text, (y_ast_set, y_texts) in zip(xs, preds, y_ast_sets):
        parsed_prediction = None
        failure = None
        try:
            parsed_prediction = string_parser.create_parse_tree(
                predicted_text, y_ast_set.type_to_choose.name)
        except AInixParseError as err:
            failure = err
        except RecursionError as err:
            warnings.warn(f"Max recursion meet. {predicted_text}", RuntimeWarning)
            failure = err
        result = AstEvaluation(
            parsed_prediction, y_ast_set, y_texts, query, failure, unparser,
            predicted_text
        )
        result.print_vals(unparser)
        logger.add_evaluation(result)
    print_ast_eval_log(logger)
def train_func(
    pid,
    model: StringTypeTranslateCF,
    index: ExamplesStore,
    batch_size,
    epochs,
    force_single_thread,
    eval_thresh,
    total_proc_count,
    working_count: torch.multiprocessing.Value,
    continue_event,
    next_cont_event,
    total_proc_time_accum,
    all_done_event
):
    """Worker entry point that trains ``model`` inside one process.

    With ``eval_thresh`` set to ``None`` the worker simply trains for
    ``epochs`` epochs. Otherwise it trains one epoch at a time and meets the
    other workers at a barrier after each epoch. The last worker to arrive
    measures time and accuracy, then releases the others.

    Args:
        pid: Label printed at start-up alongside the real OS pid.
        model: Model handed to the trainer.
        index: Example store handed to the trainer.
        batch_size: Batch size handed to the trainer.
        epochs: Maximum number of epochs to train.
        force_single_thread: When truthy, limit OpenMP and torch to a single
            thread in this process.
        eval_thresh: Exact-match fraction at which training stops, or
            ``None`` to skip the per-epoch evaluation entirely.
        total_proc_count: Value ``working_count`` is reset to once every
            worker has finished an epoch.
        working_count: Shared counter of workers still busy with the
            current epoch; modified under its own lock.
        continue_event: Event the waiting workers block on for this epoch.
        next_cont_event: Event used for the following epoch; the two are
            swapped after every barrier.
        total_proc_time_accum: Shared value; the last-finishing worker adds
            its elapsed seconds for the epoch to ``.value``.
        all_done_event: Set once the measured accuracy reaches
            ``eval_thresh``; every worker stops when it sees it set.
    """
    if force_single_thread:
        os.environ["OMP_NUM_THREADS"] = "1"
        # Must be a call: assigning to the attribute would only replace the
        # function with an int and leave the thread count untouched.
        torch.set_num_threads(1)
    print(f"start {pid} actual pid {os.getpid()}")
    trainer = TypeTranslateCFTrainer(model, index, batch_size)
    if eval_thresh is None:
        trainer.train(epochs)
        return

    for _ in range(epochs):
        start_time = datetime.datetime.now()
        trainer.train(1)
        should_wait = True
        with working_count.get_lock():
            working_count.value -= 1
            if working_count.value == 0:
                # Everyone is done with this epoch, so this worker measures.
                working_count.value = total_proc_count
                should_wait = False
                time_diff = (datetime.datetime.now() - start_time).total_seconds()
                total_proc_time_accum.value += time_diff
                logger = EvaluateLogger()
                trainer.evaluate(logger)
                acc = logger.stats['ExactMatch'].true_frac
                print(f"Curr acc {acc}")
                if acc >= eval_thresh:
                    all_done_event.set()
                # Release the waiting workers, then swap to the other event
                # and clear it for the next epoch.
                # Still could have a race, but unlikely.
                continue_event.set()
                continue_event, next_cont_event = next_cont_event, continue_event
                continue_event.clear()
        if should_wait:
            continue_event.wait()
            # Mirror the swap done by the measuring worker.
            continue_event, next_cont_event = next_cont_event, continue_event
        if all_done_event.is_set():
            break
def train(
    self,
    epochs: int,
    eval_every_n_epochs: int = None,
    intermitted_save_path=None,
    dump_each_in_eval: bool = True,
    intermitted_repl_samples: int = 10
):
    """Train the model for ``epochs`` epochs inside one train session.

    After each epoch the model's latent store (if the model exposes one and
    it is truthy) is refreshed from the TRAIN split. When
    ``eval_every_n_epochs`` is given, an intermediate evaluation is run on
    epochs that are a positive multiple of it, except on the final epoch.
    If ``intermitted_save_path`` is also given, the model is serialized
    after each such evaluation.

    Raises:
        ValueError: If an intermediate save is requested but ``self.loader``
            is ``None``.
    """
    self.model.start_train_session()
    for epoch in tqdm(range(epochs), unit="Epochs"):
        print()
        print(f"Start epoch {epoch}")
        loss = self._train_one_epoch(epoch)
        print(f"\nEpoch {epoch} complete. Loss {loss}")

        if hasattr(self.model, "plz_train_this_latent_store_thanks"):
            # TODO: this latent-store hook is a crude interface that should
            # be replaced with something cleaner. It works for now.
            latent_store = self.model.plz_train_this_latent_store_thanks()
            if latent_store:
                print("updateding the latent store 🦔")
                update_latent_store_from_examples(
                    self.model, latent_store, self.example_store,
                    self.replacer, self.string_parser,
                    (DataSplits.TRAIN, ), self.unparser, self.str_tokenizer
                )

        # Everything below is the optional intermediate evaluation.
        if not eval_every_n_epochs:
            continue
        is_final_epoch = epoch + 1 == epochs
        if is_final_epoch or epoch % eval_every_n_epochs != 0 or epoch <= 0:
            continue

        print("Pausing to do an eval")
        logger = EvaluateLogger()
        self.evaluate(
            logger,
            dump_each=dump_each_in_eval,
            num_replace_samples=intermitted_repl_samples
        )
        print_ast_eval_log(logger)

        if not intermitted_save_path:
            continue
        if self.loader is None:
            raise ValueError("Must be given loader to serialize")
        s_path = (
            f"{intermitted_save_path}_epoch{epoch}_exactmatch_"
            f"{logger.stats['ExactMatch'].percent_true_str}"
        )
        print(f"serializing to {s_path}")
        serialize(self.model, self.loader, s_path,
                  eval_results=logger, trained_epochs=epoch)
    self.model.end_train_session()
def train_to_threshold_st(batch_size, threshold=0.7, max_epochs=100):
    """Train a default model one epoch at a time until it hits ``threshold``.

    After every epoch the model is evaluated, and training stops as soon as
    the 'ExactMatch' true fraction is at least ``threshold`` or
    ``max_epochs`` epochs have run.

    Returns:
        Seconds spent inside the training calls only; evaluation time is
        not counted.
    """
    index = get_index()
    print("num docs", index.get_num_x_values())
    model = get_default_encdec_model(index, standard_size=64)
    trainer = train.TypeTranslateCFTrainer(model, index, batch_size=batch_size)

    time_spent = 0
    for _ in range(max_epochs):
        began = datetime.datetime.now()
        trainer.train(1)
        finished = datetime.datetime.now()
        time_spent += (finished - began).total_seconds()

        eval_log = EvaluateLogger()
        trainer.evaluate(eval_log)
        acc = eval_log.stats['ExactMatch'].true_frac
        print(f"Curr acc {acc}")
        if acc >= threshold:
            break
    return time_spent
if __name__ == "__main__":
    # Script entry point: evaluate an untrained default model, then train it
    # with the multiprocess trainer.
    # NOTE(review): the first list looks like type-definition files and the
    # second like example files, judging by the file names only - confirm
    # against example_store_fac.
    index_fac = example_store_fac([
        '../../builtin_types/numbers.ainix.yaml',
        '../../builtin_types/generic_parsers.ainix.yaml'
    ], [
        "../../builtin_types/numbers_examples.ainix.yaml"
    ])
    batch_size = 4
    index = index_fac()
    model = get_default_encdec_model(examples=index)

    # Try before: baseline evaluation prior to any training.
    print("before:")
    trainer = TypeTranslateCFTrainer(model, index, batch_size)
    logger = EvaluateLogger()
    trainer.evaluate(logger)
    print_ast_eval_log(logger)

    # train
    mptrainer = MultiprocTrainer(
        model,
        make_default_trainer_fac(
            model,
            batch_size
        )
    )
    # NOTE(review): the meaning of the positional 5 and 10 is not visible
    # from here - confirm against MultiprocTrainer.train.
    mptrainer.train(5, 10, index, batch_size)
    # Fresh logger; presumably consumed by code following this chunk.
    logger = EvaluateLogger()