示例#1
0
def test_should_not_override_old_inference_log(patched_dataset_reading,
                                               patched_params):
    """A second inference run must add a new log file, not overwrite the first.

    Runs inference twice for the same ``run_data`` and compares the
    ``st_ctime`` values of the ``consts.LOG`` files found in the inference
    directory after each run.
    """
    run_data = gen.run_data(model=MnistContrastiveModel())

    def log_ctimes(directory):
        # ``iterdir()`` is consumed once, so no intermediate ``list`` is needed.
        # Note: on POSIX ``st_ctime`` is the last metadata-change time, not the
        # creation time; an untouched file keeps its value across runs.
        return [
            entry.stat().st_ctime
            for entry in Path(directory).iterdir()
            if entry.suffix == consts.LOG
        ]

    inference.single_run_inference(run_data=run_data, show=False)
    infer_results_dir_path = filenames.get_infer_dir(run_data)
    old_cr_times = log_ctimes(infer_results_dir_path)

    inference.single_run_inference(run_data=run_data, show=False)
    new_cr_times = log_ctimes(infer_results_dir_path)

    assert len(old_cr_times) == 1
    assert len(new_cr_times) == 2
    # The first run's log must still be present with an unchanged ctime.
    assert old_cr_times[0] in new_cr_times
示例#2
0
def test_should_create_correct_infer_directory_for_single_model_launcher(
        patched_params):
    """A single-model run infers into ``<home infer dir>/<model summary>``."""
    run_data = gen.run_data()

    actual_dir = filenames.get_infer_dir(run_data)
    home_infer_dir = filenames._get_home_infer_dir()
    expected_dir = home_infer_dir / run_data.model.summary

    assert actual_dir == expected_dir
示例#3
0
def test_should_create_correct_infer_directory_for_experiment_launcher(
        patched_params):
    """An experiment run nests the model summary under the launcher name.

    Expected layout: ``<home infer dir>/<launcher name>/<model summary>``.
    """
    run_data = gen.run_data(is_experiment=True)

    actual_dir = filenames.get_infer_dir(run_data)
    home_infer_dir = filenames._get_home_infer_dir()
    expected_dir = (home_infer_dir
                    / run_data.launcher_name
                    / run_data.model.summary)

    assert actual_dir == expected_dir
示例#4
0
def test_should_create_summaries_for_different_models(patched_dataset_reading,
                                                      patched_params):
    """Inference for each provided model leaves a non-empty results directory.

    ``patched_dataset_reading.param`` is called to build the model, so it is
    presumably a model class supplied by fixture parametrization.
    """
    model_factory = patched_dataset_reading.param
    run_data = gen.run_data(model=model_factory())

    inference.single_run_inference(run_data=run_data, show=False)

    results_dir = filenames.get_infer_dir(run_data)
    dir_is_populated = utils.check_filepath(results_dir,
                                            is_directory=True,
                                            is_empty=False)
    assert dir_is_populated
示例#5
0
def prepare_infer_env(run_data: RunData):
    """Prepare global config and logging before running inference.

    Steps, in order:
    1. Push the model params and launcher params of ``run_data`` into
       ``config``.
    2. Install logging handlers writing to a fresh log file inside the
       run's inference directory.
    3. Log where inference output goes.
    4. Call ``_check_model_checkpoint_existence`` and then
       ``_log_inference_model`` for this run.
    """
    config.update_model_params(run_data.model.params)
    config.update_launcher_params(run_data.launcher_params)

    inference_dir = filenames.get_infer_dir(run_data)
    log_path = inference_dir / filenames.create_infer_log_name(run_data.model)
    _set_logging_handlers([log_path])

    utils.log("Inference data will be saved into: {}".format(inference_dir))
    _check_model_checkpoint_existence(run_data)
    _log_inference_model(run_data)
示例#6
0
def copy_text_log(run_data):
    """Copy the run's latest text log into its inference directory.

    The copied file keeps its original file name. If the run's text-log
    directory does not exist, a message is logged and nothing is copied.
    """
    # Local import is kept so this function does not depend on file-level
    # imports, but it now sits at the top of the body instead of mid-way.
    import shutil

    inference_dir = filenames.get_infer_dir(run_data)
    run_text_logs_dir = filenames.get_run_text_logs_dir(run_data)
    if not run_text_logs_dir.exists():
        utils.log(
            "{} not exists - not copying text log".format(run_text_logs_dir))
        return

    # NOTE(review): assumes find_latest_ing_dir always returns a Path when the
    # directory exists (never None for an empty dir) — confirm.
    latest_log = find_latest_ing_dir(run_text_logs_dir)
    # shutil.copy accepts path-like objects, and relative paths resolve against
    # the cwd either way, so the former str(...absolute()) wrapping was redundant.
    shutil.copy(latest_log, inference_dir / latest_log.name)
示例#7
0
def test_should_create_correct_number_of_inference_files(
        patched_dataset_reading, patched_params):
    """Inference writes 4 files for models with 2-D embeddings, otherwise 2.

    Presumably this is the inference log plus the pairs board, and for 2-D
    embeddings also the distances and clusters plots — confirm if plots change.
    """
    model = patched_dataset_reading.param()
    run_data = gen.run_data(model=model)

    inference.single_run_inference(run_data=run_data, show=False)

    results_dir = filenames.get_infer_dir(run_data)
    if model.produces_2d_embedding:
        expected_file_count = 4
    else:
        expected_file_count = 2
    assert utils.check_filepath(results_dir,
                                is_directory=True,
                                is_empty=False,
                                expected_len=expected_file_count)
示例#8
0
def plot_predicted_data(run_data: RunData, dicts_dataset: DictsDataset,
                        predictions: Dict[str, np.ndarray], show):
    """Render inference plots for ``predictions`` into the inference directory.

    Always draws the pairs board. When the model produces 2-D embeddings it
    also draws the distances plot and the clusters plot, built from the
    left/right embeddings stored in ``predictions``.

    Args:
        run_data: run whose ``model`` and inference directory are used.
        dicts_dataset: dataset being plotted; its ``labels`` (with ``left`` and
            ``right`` parts) are used for the 2-D plots.
        predictions: model outputs, looked up by ``consts.INFERENCE_*`` keys
            and passed to the model's label/score accessors.
        show: forwarded to every ``image_summaries`` plotting call.
    """
    model = run_data.model
    inference_dir = filenames.get_infer_dir(run_data)

    def plot_path(name):
        # Every plot is a PNG named from the model summary without a date
        # fragment (presumably so a re-run overwrites the previous plots).
        return inference_dir / filenames.summary_to_name(
            model,
            suffix=consts.PNG,
            with_date_fragment=False,
            name=name)

    utils.log("Plotting pairs...")
    image_summaries.create_pairs_board(
        dataset=dicts_dataset,
        predicted_labels=model.get_predicted_labels(predictions),
        predicted_scores=model.get_predicted_scores(predictions),
        path=plot_path(consts.INFER_PLOT_BOARD_NAME),
        show=show)

    if not model.produces_2d_embedding:
        return

    left_embeddings = predictions[consts.INFERENCE_LEFT_EMBEDDINGS]
    right_embeddings = predictions[consts.INFERENCE_RIGHT_EMBEDDINGS]

    utils.log("Plotting distances...")
    x, y = map_pair_of_points_to_plot_data(left_embeddings, right_embeddings)
    labels = dicts_dataset.labels
    image_summaries.create_distances_plot(
        left_coors=x,
        right_coors=y,
        labels_dict=labels,
        infer_result=predictions,
        path=plot_path(consts.INFER_PLOT_DISTANCES_NAME),
        show=show)

    utils.log("Plotting clusters...")
    image_summaries.create_clusters_plot(
        feat=np.concatenate((left_embeddings, right_embeddings)),
        labels=np.concatenate((labels.left, labels.right)),
        # Was a bare INFER_PLOT_CLUSTERS_NAME; qualified like its siblings.
        path=plot_path(consts.INFER_PLOT_CLUSTERS_NAME),
        show=show)
示例#9
0
def test_should_override_plots_with_newer_inference(patched_dataset_reading,
                                                    patched_params):
    """A second inference run must rewrite every non-log artifact (the plots).

    Collects the ``st_ctime`` of every non-``consts.LOG`` file after each of
    two inference runs. Every ctime seen after the first run must be absent
    after the second, i.e. each such file was rewritten.
    """
    run_data = gen.run_data(model=MnistContrastiveModel())

    def non_log_ctimes(directory):
        # ``iterdir()`` is consumed once, so no intermediate ``list`` is needed.
        return [
            entry.stat().st_ctime
            for entry in Path(directory).iterdir()
            if entry.suffix != consts.LOG
        ]

    inference.single_run_inference(run_data=run_data, show=False)
    infer_results_dir_path = filenames.get_infer_dir(run_data)
    old_cr_times = non_log_ctimes(infer_results_dir_path)

    inference.single_run_inference(run_data=run_data, show=False)
    new_cr_times = non_log_ctimes(infer_results_dir_path)

    # only_contains also fails on an empty sequence, so this additionally
    # guards against the first run producing no plots at all.
    was_rewritten = [(ctime not in new_cr_times) for ctime in old_cr_times]
    assert_that(was_rewritten, only_contains(True))
示例#10
0
        print("Connecting to remote host...")
        ssh_client.connect(hostname=secrets.REMOTE_HOST, username=secrets.USER, password=secrets.PASSWORD)

        filename = 'infer.zip'
        full_name = infer_dir / filename
        print("Zipping contests into {}".format(str(full_name)))
        (a, b, c) = ssh_client.exec_command("cd {}; zip -r {} ./".format(str(infer_dir), str(full_name)))
        print(b.readlines())
        print(c.readlines())

        with ssh_client.open_sftp() as ftp_client:
            localdir = Path("/tmp") / str(infer_dir.parts[-2]) / (
                        str(infer_dir.parts[-1]) + '_' + str(time.strftime('d%y%m%dt%H%M%S')))
            localdir.mkdir(parents=True, exist_ok=True)

            print("Downloading {} into {}".format(full_name, str(localdir / filename)))
            ftp_client.get(str(infer_dir / filename), str(localdir / filename))

    print("Connections closed.")
    show_images_in_dir_via_eog(localdir, filename)


if __name__ == '__main__':
    # Build the run description and push its params into the global config,
    # then derive the inference directory for it.
    run_data = providing_launcher.provide_single_run_data()
    config.update_model_params(run_data.model.params)
    config.update_launcher_params(run_data.launcher_params)

    # NOTE(review): every 'antek' substring in the path is rewritten to 'ant',
    # presumably mapping a local user name to the remote one — confirm before
    # running on another machine.
    local_infer_dir = filenames.get_infer_dir(run_data)
    inference_dir = Path(str(local_infer_dir).replace('antek', 'ant'))
    ssh_download_and_open(inference_dir)