Example #1
    def evaluate(self,
                 logger: EvaluateLogger,
                 filter_splits: Optional[Tuple[DataSplits, ...]] = (
                     DataSplits.VALIDATION, ),
                 dump_each: bool = False,
                 num_replace_samples=1):
        self.model.set_in_eval_mode()
        # A somewhat hacky approximation of sampling replacements multiple
        # times: just iterate over everything multiple times.
        dups = [
            list(self.data_pair_iterate(filter_splits))
            for _ in range(num_replace_samples)
        ]
        dups = flatten_list(dups)
        dups.sort(key=lambda d: d[0].id)

        last_x = None
        last_eval_result = None
        last_example_id = None
        for data in dups:
            example, replaced_x_query, y_ast_set, this_example_ast, y_texts, rsample = data
            if last_x == replaced_x_query:
                logger.add_evaluation(last_eval_result)
                continue
            parse_exception = None
            try:
                prediction, metad = self.model.predict(
                    replaced_x_query, example.get_y_type(self.example_store),
                    True)
            except ModelCantPredictException as e:
                prediction = None
                parse_exception = e
            except ModelSafePredictError as e:
                prediction = None
                parse_exception = e
            #print("predict", prediction, "expect", y_ast_set, "ytext", y_texts,
            #      "replx", replaced_x_query)
            eval = AstEvaluation(prediction, y_ast_set, y_texts,
                                 replaced_x_query, parse_exception,
                                 self.unparser)
            last_x = replaced_x_query
            last_eval_result = eval
            logger.add_evaluation(eval)
            if dump_each:
                if example.id != last_example_id:
                    print("---")
                eval.print_vals(self.unparser)
                last_example_id = example.id

        self.model.set_in_train_mode()

def assert_acc(model, example_store, splits, required_accuracy, expect_fail):
    """Check whether a model reaches a required accuracy on the given splits."""
    trainer = TypeTranslateCFTrainer(model, example_store)
    logger = EvaluateLogger()
    trainer.evaluate(logger, splits)
    acc = logger.stats['ExactMatch'].true_frac
    if expect_fail:
        assert acc < required_accuracy
    else:
        assert acc >= required_accuracy
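
For context, here is a minimal usage sketch of assert_acc. The setup (model construction, number of training epochs, the 0.5 threshold) is an illustrative assumption, not part of the original listing.

# Hypothetical test sketch; setup values are assumptions, not from the source.
model = get_default_encdec_model(examples=example_store)
trainer = TypeTranslateCFTrainer(model, example_store)
trainer.train(10)
# Expect at least 50% exact match on the validation split after training.
assert_acc(model, example_store, (DataSplits.VALIDATION,), 0.5, expect_fail=False)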
Example #3
def eval_stuff(xs: List[str], ys: List[str], preds: List[str],
               y_ast_sets: Iterable[Tuple[AstObjectChoiceSet, Set[str]]],
               string_parser: StringParser, unparser: AstUnparser):
    logger = EvaluateLogger()
    assert len(xs) == len(preds)
    for x, pred, (y_ast_set, y_texts) in zip(xs, preds, y_ast_sets):
        pexception = None
        pred_ast = None
        try:
            pred_ast = string_parser.create_parse_tree(
                pred, y_ast_set.type_to_choose.name)
        except AInixParseError as e:
            pexception = e
        except RecursionError as e:
            warnings.warn(f"Max recursion meet. {pred}", RuntimeWarning)
            pexception = e
        evaluation = AstEvaluation(pred_ast, y_ast_set, y_texts, x, pexception,
                                   unparser, pred)
        evaluation.print_vals(unparser)
        logger.add_evaluation(evaluation)
    print_ast_eval_log(logger)
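
A hedged driver sketch for eval_stuff follows. batch_predict and make_y_ast_set are assumed helper names (not from the source) standing in for however the model predictions and gold AST sets are actually produced; per the unpacking inside eval_stuff, each element of y_ast_sets pairs an AstObjectChoiceSet with its set of gold strings.

# Hypothetical driver; batch_predict and make_y_ast_set are assumed helpers.
xs = ["list all files", "print hello"]
ys = ["ls", "echo hello"]
preds = batch_predict(model, xs)  # assumed: one prediction string per x
# Assumed: returns (AstObjectChoiceSet of acceptable parses, set of gold texts).
y_ast_sets = [make_y_ast_set(y, string_parser) for y in ys]
eval_stuff(xs, ys, preds, y_ast_sets, string_parser, unparser)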
Example #4
def train_func(pid, model: StringTypeTranslateCF, index: ExamplesStore, batch_size, epochs,
               force_single_thread, eval_thresh, total_proc_count,
               working_count: torch.multiprocessing.Value,
               continue_event, next_cont_event, total_proc_time_accum,
               all_done_event):
    if force_single_thread:
        os.environ["OMP_NUM_THREADS"] = "1"
        torch.set_num_threads(1)
    print(f"start {pid} actual pid {os.getpid()}")
    #print(f"example count {index.get_num_x_values()}")
    #for i in range(10):
    #    print(f"feelin racey? {len(list(index.get_all_x_values()))}")
    trainer = TypeTranslateCFTrainer(model, index, batch_size)
    if eval_thresh is None:
        trainer.train(epochs)
    else:
        for e in range(epochs):
            start_time = datetime.datetime.now()
            trainer.train(1)
            should_wait = True
            with working_count.get_lock():
                working_count.value -= 1
                if working_count.value == 0:
                    working_count.value = total_proc_count
                    # all workers are done; record elapsed time and run the eval
                    should_wait = False
                    time_diff = (datetime.datetime.now() - start_time).total_seconds()
                    total_proc_time_accum.value += time_diff
                    logger = EvaluateLogger()
                    trainer.evaluate(logger)
                    acc = logger.stats['ExactMatch'].true_frac
                    print(f"Curr acc {acc}")
                    if acc >= eval_thresh:
                        all_done_event.set()
                    continue_event.set()
                    # swap to the new event; a race is still possible here, but unlikely
                    continue_event, next_cont_event = next_cont_event, continue_event
                    continue_event.clear()
            if should_wait:
                continue_event.wait()
                continue_event, next_cont_event = next_cont_event, continue_event
            if all_done_event.is_set():
                break
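
To make the synchronization protocol concrete, here is a hedged sketch of how these workers might be launched. launch_workers and its argument values are assumptions, not part of the original code, and how the model and example store are shared across processes depends on their implementations and is glossed over here. Each worker decrements working_count after an epoch; the last one to finish runs the eval and releases the others via the event pair.

import torch.multiprocessing as mp

# Hypothetical launcher; names and argument values are assumptions.
def launch_workers(model, index, n_procs=4, batch_size=4, epochs=20, eval_thresh=0.8):
    working_count = mp.Value('i', n_procs)   # countdown of workers still training
    total_proc_time = mp.Value('d', 0.0)     # accumulated per-epoch wall time
    continue_event, next_cont_event = mp.Event(), mp.Event()
    all_done_event = mp.Event()
    procs = [mp.Process(target=train_func,
                        args=(pid, model, index, batch_size, epochs,
                              True, eval_thresh, n_procs, working_count,
                              continue_event, next_cont_event,
                              total_proc_time, all_done_event))
             for pid in range(n_procs)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    return total_proc_time.value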
Example #5
    def train(self,
              epochs: int,
              eval_every_n_epochs: Optional[int] = None,
              intermitted_save_path=None,
              dump_each_in_eval: bool = True,
              intermitted_repl_samples: int = 10):
        self.model.start_train_session()
        for epoch in tqdm(range(epochs), unit="Epochs"):
            print()
            print(f"Start epoch {epoch}")
            loss = self._train_one_epoch(epoch)
            print(f"\nEpoch {epoch} complete. Loss {loss}")
            if hasattr(self.model, "plz_train_this_latent_store_thanks"):
                # TODO: replace this ad-hoc latent-store hook with a proper
                # interface; it works for now.
                latent_store = self.model.plz_train_this_latent_store_thanks()
                if latent_store:
                    print("updateding the latent store 🦔")
                    update_latent_store_from_examples(
                        self.model, latent_store, self.example_store,
                        self.replacer, self.string_parser,
                        (DataSplits.TRAIN, ), self.unparser,
                        self.str_tokenizer)
            if (eval_every_n_epochs
                    and epoch + 1 != epochs
                    and epoch % eval_every_n_epochs == 0
                    and epoch > 0):
                print("Pausing to do an eval")
                logger = EvaluateLogger()
                self.evaluate(logger,
                              dump_each=dump_each_in_eval,
                              num_replace_samples=intermitted_repl_samples)
                print_ast_eval_log(logger)
                if intermitted_save_path:
                    if self.loader is None:
                        raise ValueError("Must be given loader to serialize")
                    s_path = f"{intermitted_save_path}_epoch{epoch}_exactmatch_" + \
                             f"{logger.stats['ExactMatch'].percent_true_str}"
                    print(f"serializing to {s_path}")
                    serialize(self.model,
                              self.loader,
                              s_path,
                              eval_results=logger,
                              trained_epochs=epoch)

        self.model.end_train_session()
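
A brief usage sketch of this training loop, assuming a trainer whose loader is set so checkpointing works; the epoch counts and checkpoint path are illustrative assumptions.

# Hypothetical call; epoch counts and path are assumptions.
trainer = TypeTranslateCFTrainer(model, example_store, batch_size=4)
trainer.train(
    epochs=30,
    eval_every_n_epochs=5,                        # pause for an eval every 5 epochs
    intermitted_save_path="./checkpoints/model",  # requires the trainer's loader to be set
    dump_each_in_eval=False,
)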
Example #6
def train_to_threshold_st(batch_size, threshold=0.7, max_epochs=100):
    index = get_index()
    num_docs = index.get_num_x_values()
    print("num docs", num_docs)

    model = get_default_encdec_model(index, standard_size=64)
    trainer = train.TypeTranslateCFTrainer(model, index, batch_size=batch_size)
    time_spent = 0
    for epoch in range(max_epochs):
        start_time = datetime.datetime.now()
        trainer.train(1)
        time_spent += (datetime.datetime.now() - start_time).total_seconds()
        logger = EvaluateLogger()
        trainer.evaluate(logger)
        acc = logger.stats['ExactMatch'].true_frac
        print(f"Curr acc {acc}")
        if acc >= threshold:
            break

    return time_spent
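
For reference, a hedged sketch of how this helper might be used to compare training cost across batch sizes; the batch sizes chosen here are assumptions.

# Hypothetical timing comparison; batch sizes are assumptions.
for bs in (1, 4, 16):
    seconds = train_to_threshold_st(batch_size=bs)
    print(f"batch_size={bs}: hit the accuracy threshold after {seconds:.1f}s of training")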
Example #7

if __name__ == "__main__":
    index_fac = example_store_fac([
        '../../builtin_types/numbers.ainix.yaml',
        '../../builtin_types/generic_parsers.ainix.yaml'
    ], [
        "../../builtin_types/numbers_examples.ainix.yaml"
    ])
    batch_size = 4
    index = index_fac()
    model = get_default_encdec_model(examples=index)
    # Evaluate before training
    print("before:")
    trainer = TypeTranslateCFTrainer(model, index, batch_size)
    logger = EvaluateLogger()
    trainer.evaluate(logger)
    print_ast_eval_log(logger)

    # Train with multiple worker processes

    mptrainer = MultiprocTrainer(
        model,
        make_default_trainer_fac(
            model,
            batch_size
        )
    )
    mptrainer.train(5, 10, index, batch_size)

    # Evaluate again after training
    print("after:")
    logger = EvaluateLogger()
    trainer.evaluate(logger)
    print_ast_eval_log(logger)