Example #1
0
def test_resumable_pipeline_fit_transform_should_save_all_fitted_pipeline_steps(
        tmpdir: LocalPath):
    """Fit-transform a nested ResumablePipeline and check which paths exist.

    The pipeline is SOME_STEP_1 followed by a nested pipeline made of
    SOME_STEP_2, a DefaultCheckpoint and SOME_STEP_3, cached under ``tmpdir``.
    After ``fit_transform`` the test asserts that:

    * the outputs equal EXPECTED_OUTPUTS;
    * the paths for the root, the nested pipeline, step 1, step 2 and the
      checkpoint exist;
    * the path for step 3 (the step placed after the checkpoint) does not.
    """
    # Steps are instantiated in pipeline order: step 1 first, then the
    # contents of the nested pipeline.
    first_step = MultiplyByN(multiply_by=2)
    nested_pipeline = ResumablePipeline([
        (SOME_STEP_2, MultiplyByN(multiply_by=4)),
        (CHECKPOINT, DefaultCheckpoint()),
        (SOME_STEP_3, MultiplyByN(multiply_by=6)),
    ])
    pipeline = ResumablePipeline(
        [(SOME_STEP_1, first_step), (PIPELINE_2, nested_pipeline)],
        cache_folder=tmpdir,
    )
    pipeline.name = ROOT

    data_inputs = np.array(range(10))
    expected_outputs = np.array(range(10))
    pipeline, outputs = pipeline.fit_transform(data_inputs, expected_outputs)

    paths_expected_missing = [create_some_step3_path(tmpdir)]
    paths_expected_on_disk = [
        create_root_path(tmpdir),
        create_pipeline2_path(tmpdir),
        create_some_step1_path(tmpdir),
        create_some_step2_path(tmpdir),
        create_some_checkpoint_path(tmpdir),
    ]

    assert np.array_equal(outputs, EXPECTED_OUTPUTS)
    # Distinct loop names keep the fitted pipeline variable from being
    # rebound to a path.
    for saved_path in paths_expected_on_disk:
        assert os.path.exists(saved_path)
    for missing_path in paths_expected_missing:
        assert not os.path.exists(missing_path)
Example #2
0
def given_saved_pipeline(tmpdir: LocalPath):
    """Dump a saved-pipeline fixture under ``tmpdir`` and return a new pipeline.

    Dumped, in order: an empty root ResumablePipeline (named ROOT), an empty
    nested ResumablePipeline, three steps saved through
    ``given_saved_some_step`` with ``multiply_by`` 2, 4 and 6, and a
    DefaultCheckpoint named CHECKPOINT.

    :param tmpdir: folder used as ``cache_folder`` and passed to the
        ``create_*_path`` helpers.
    :return: a freshly built ResumablePipeline named ROOT with the same step
        names as the dumped fixture, but whose MultiplyByN steps all use
        ``multiply_by=1`` (presumably so callers can tell whether the saved
        steps were loaded - confirm against the tests using this helper).
    """
    root_savers = [
        (SOME_STEP_1, []),
        (PIPELINE_2, [TruncableJoblibStepSaver()]),
    ]
    root_path = create_root_path(tmpdir, True)
    saved_root = ResumablePipeline([], cache_folder=tmpdir)
    saved_root.sub_steps_savers = root_savers
    saved_root.name = ROOT
    dump(saved_root, root_path)

    saved_nested = ResumablePipeline([], cache_folder=tmpdir)
    # NOTE(review): string literal rather than the PIPELINE_2 constant used
    # elsewhere - confirm that the two values match.
    saved_nested.name = 'pipeline2'
    saved_nested.sub_steps_savers = [
        (step_name, [])
        for step_name in (SOME_STEP_2, CHECKPOINT, SOME_STEP_3)
    ]
    nested_path = create_pipeline2_path(tmpdir, True)
    dump(saved_nested, nested_path)

    # (multiply_by, step name, path factory) for every individually saved
    # step; each path is created right before its step is saved.
    saved_steps = (
        (2, SOME_STEP_1, create_some_step1_path),
        (4, SOME_STEP_2, create_some_step2_path),
        (6, SOME_STEP_3, create_some_step3_path),
    )
    for factor, step_name, make_step_path in saved_steps:
        given_saved_some_step(multiply_by=factor,
                              name=step_name,
                              path=make_step_path(tmpdir, True))

    saved_checkpoint = DefaultCheckpoint()
    saved_checkpoint.name = CHECKPOINT
    checkpoint_path = create_some_checkpoint_path(tmpdir, True)
    dump(saved_checkpoint, checkpoint_path)

    # Pipeline handed back to the caller: same layout as the dumped one,
    # every multiplier set to 1.
    first_step = MultiplyByN(multiply_by=1)
    nested_pipeline = ResumablePipeline([
        (SOME_STEP_2, MultiplyByN(multiply_by=1)),
        (CHECKPOINT, DefaultCheckpoint()),
        (SOME_STEP_3, MultiplyByN(multiply_by=1)),
    ])
    fresh_pipeline = ResumablePipeline(
        [(SOME_STEP_1, first_step), (PIPELINE_2, nested_pipeline)],
        cache_folder=tmpdir,
    )
    fresh_pipeline.name = ROOT

    return fresh_pipeline