Esempio n. 1
0
def test_on_raw_chunks(model_path: RichPath,
                       test_hyper_overrides: Dict[str, Any],
                       snippet_output_folder: str,
                       proc_id: int,
                       test_raw_data_chunks: List[RichPath]) -> Tuple[int, List[float], int, int]:
    """Test a restored model on raw data chunks and collect accuracy statistics.

    For every sample reported by the model, a C# snippet is written to
    ``snippet_output_folder`` as ``sample_<proc_id>-<sample_idx>.cs``. The
    snippet is built from the top prediction, or from ``'???'`` when the model
    produced no prediction at all.

    Args:
        model_path: Location of the saved model to restore.
        test_hyper_overrides: Hyperparameter overrides applied when restoring
            the model. Must contain a ``'run_id'`` entry. The dict is not
            modified; a copy receives the per-process adjustments.
        snippet_output_folder: Directory in which snippet files are written.
        proc_id: Identifier of this worker; appended to the run id and used in
            snippet file names.
        test_raw_data_chunks: Raw data chunks handed to ``model.test``.

    Returns:
        A tuple ``(num_samples, token_perplexities, correct_at_1,
        correct_at_5)``, where ``num_samples`` is whatever ``model.test``
        returns and the remaining values are accumulated by the per-sample
        callback.
    """
    def write_snippet(snippet_id: int, content: str) -> None:
        # One file per (process, sample) pair so parallel workers never collide.
        snippet_path = os.path.join(snippet_output_folder, 'sample_%i-%i.cs' % (proc_id, snippet_id))
        with open(snippet_path, 'w', encoding='utf-8') as f:
            f.write(content)

    # Mutable accumulator shared with the callback below.
    results = {"correct_at_1": 0,
               "correct_at_5": 0,
               "token_perplexities": []}

    def per_result_callback(sample_idx, token_perplexity, raw_sample, sample_result):
        predictions = sample_result.all_predictions
        results["token_perplexities"].append(token_perplexity)
        if len(predictions) == 0:
            write_snippet(sample_idx, build_csharp_check_function(raw_sample, '???'))  # A full error
            return
        # Each prediction's first element is the token sequence that is compared.
        if token_seq_equal(predictions[0][0], sample_result.ground_truth):
            results["correct_at_1"] += 1
        if any(token_seq_equal(prediction[0], sample_result.ground_truth) for prediction in predictions[:5]):
            results["correct_at_5"] += 1
        write_snippet(sample_idx, build_csharp_check_function(raw_sample, ' '.join(predictions[0][0])))

    # Work on a shallow copy: writing into the caller's dict would make the
    # run_id suffix accumulate ("-0-1-...") if the same dict were reused for
    # another call in the same process.
    hyper_overrides = dict(test_hyper_overrides)
    hyper_overrides['run_id'] = hyper_overrides['run_id'] + "-" + str(proc_id)
    # NOTE(review): an empty device id presumably keeps the worker off the GPU - confirm.
    hyper_overrides['gpu_device_id'] = ""
    # The model is restored twice, once in training mode and once in test mode;
    # model.test takes the training-mode instance as an argument.
    train_model = model_restore_helper.restore(model_path, is_train=True, hyper_overrides=hyper_overrides)
    model = model_restore_helper.restore(model_path, is_train=False, hyper_overrides=hyper_overrides)
    num_samples = model.test(test_raw_data_chunks, per_result_callback_fn=per_result_callback, train_model=train_model)
    return num_samples, results["token_perplexities"], results["correct_at_1"], results["correct_at_5"]
Esempio n. 2
0
def run_train(model_class: Type[Model],
              train_data_path: RichPath,
              valid_data_path: RichPath,
              save_folder: str,
              hyperparameters: Dict[str, Any],
              run_name: Optional[str]=None,
              quiet: bool=False) \
        -> RichPath:
    """Train a model, resuming from an existing save file when one is present.

    A fresh ``model_class`` instance is created first. If the path it reports
    as ``model_save_path`` already exists, the model is restored from there and
    training is resumed. Otherwise metadata is loaded from ``metadata.pkl.gz``
    under ``train_data_path`` and a new model is built.

    Args:
        model_class: Model class to instantiate.
        train_data_path: Directory holding the ``chunk_*`` training files and,
            for a fresh run, ``metadata.pkl.gz``.
        valid_data_path: Directory holding the ``valid_chunk_*`` files.
        save_folder: Passed to the model as both model and log directory.
        hyperparameters: Hyperparameters for the model; must contain
            ``'run_id'`` and be JSON-serialisable (it is logged).
        run_name: Optional run name forwarded to the model constructor.
        quiet: Forwarded to ``model.train``.

    Returns:
        The value returned by ``model.train``.
    """
    train_chunks = train_data_path.get_filtered_files_in_dir('chunk_*')
    valid_chunks = valid_data_path.get_filtered_files_in_dir('valid_chunk_*')

    model = model_class(hyperparameters, run_name=run_name, model_save_dir=save_folder, log_save_dir=save_folder)

    # The presence of a saved model decides between resuming and starting fresh.
    resume = os.path.exists(model.model_save_path)
    if resume:
        model = model_restore_helper.restore(RichPath.create(model.model_save_path), is_train=True)
        log_template = "Resuming training run %s of model %s with following hypers:\n%s"
    else:
        model.load_existing_metadata(train_data_path.join('metadata.pkl.gz'))
        model.make_model(is_train=True)
        log_template = "Starting training run %s of model %s with following hypers:\n%s"

    # Logged after the model is ready so the class name reflects the instance
    # that will actually be trained.
    log_args = (hyperparameters['run_id'], model.__class__.__name__, json.dumps(hyperparameters))
    model.train_log(log_template % log_args)

    return model.train(train_chunks, valid_chunks, quiet=quiet, resume=resume)