コード例 #1
0
def test_queued_pipeline_saving(tmpdir):
    """Saving a ParallelQueuedFeatureUnion must persist each step's callback data.

    The pipeline is fitted, saved to ``tmpdir``, and then has its callbacks
    cleared in memory. Reloading it from ``tmpdir`` must restore the callback
    data recorded before the save, which proves that data came from disk.
    """
    # Given
    batch_size = 10
    data_inputs = list(range(100))
    expected_outputs = list(range(100))
    steps = [
        ('1', FitTransformCallbackStep()),
        ('2', FitTransformCallbackStep()),
        ('3', FitTransformCallbackStep()),
        ('4', FitTransformCallbackStep()),
    ]
    n_steps = len(steps)
    p = ParallelQueuedFeatureUnion(steps, n_workers_per_step=1,
                                   max_queue_size=10, batch_size=batch_size)

    # Each step is expected to record one fit and one transform callback call
    # per batch (100 items / batch_size 10 == 10). The data size is an exact
    # multiple of batch_size, so floor division gives the batch count.
    expected_batches = len(data_inputs) // batch_size

    def assert_callback_counts(pipeline, expected_count):
        # Check every step's transform and fit callback data, in step order.
        for index in range(n_steps):
            callback_step = pipeline[index].wrapped
            n_transform = len(callback_step.transform_callback_function.data)
            n_fit = len(callback_step.fit_callback_function.data)
            assert n_transform == expected_count, (
                'step {}: expected {} transform callback calls, got {}'.format(
                    index, expected_count, n_transform))
            assert n_fit == expected_count, (
                'step {}: expected {} fit callback calls, got {}'.format(
                    index, expected_count, n_fit))

    # When
    p, _ = p.fit_transform(data_inputs, expected_outputs)
    p.save(ExecutionContext(tmpdir))
    # Wipe the in-memory callback data so the reload below must restore it
    # from what was saved.
    p.apply('clear_callbacks')

    # Then
    assert_callback_counts(p, 0)

    p = p.load(ExecutionContext(tmpdir))

    assert_callback_counts(p, expected_batches)
コード例 #2
0
def test_queued_pipeline_saving(tmpdir, use_processes, use_savers):
    """Saving a ParallelQueuedFeatureUnion must persist each step's callback data.

    The pipeline runs inside an execution context rooted at ``tmpdir``. It is
    fitted, unwrapped from that context, saved, and then has its callbacks
    cleared in memory. Reloading from ``tmpdir`` must restore the callback
    data recorded before the save.

    ``use_processes`` and ``use_savers`` are forwarded unchanged to
    ``ParallelQueuedFeatureUnion``. Presumably they are pytest parameters or
    fixtures; confirm against the test module.
    """
    # Given
    batch_size = 10
    data_inputs = list(range(200))
    expected_outputs = list(range(200))
    # NOTE(review): the per-step tuples look like
    # (name, n_workers, max_queue_size, step), mirroring the keyword
    # arguments below. Confirm against ParallelQueuedFeatureUnion.
    steps = [
        ('1', 4, 10, FitTransformCallbackStep()),
        ('2', 4, 10, FitTransformCallbackStep()),
        ('3', 4, 10, FitTransformCallbackStep()),
        ('4', 4, 10, FitTransformCallbackStep()),
    ]
    n_steps = len(steps)
    p = ParallelQueuedFeatureUnion(
        steps,
        n_workers_per_step=4,
        max_queue_size=10,
        batch_size=batch_size,
        use_processes=use_processes,
        use_savers=use_savers,
    ).with_context(ExecutionContext(tmpdir))

    # Each step is expected to record one fit and one transform callback call
    # per batch (200 items / batch_size 10 == 20). The data size is an exact
    # multiple of batch_size, so floor division gives the batch count.
    expected_batches = len(data_inputs) // batch_size

    def assert_callback_counts(pipeline, expected_count):
        # Check every step's transform and fit callback data, in step order.
        for index in range(n_steps):
            callback_step = pipeline[index].wrapped
            n_transform = len(callback_step.transform_callback_function.data)
            n_fit = len(callback_step.fit_callback_function.data)
            assert n_transform == expected_count, (
                'step {}: expected {} transform callback calls, got {}'.format(
                    index, expected_count, n_transform))
            assert n_fit == expected_count, (
                'step {}: expected {} fit callback calls, got {}'.format(
                    index, expected_count, n_fit))

    # When
    p, _ = p.fit_transform(data_inputs, expected_outputs)
    p = p.wrapped  # strip the wrapper added by with_context(...)
    p.save(ExecutionContext(tmpdir))
    # Wipe the in-memory callback data so the reload below must restore it
    # from what was saved.
    p.apply('clear_callbacks')

    # Then
    assert_callback_counts(p, 0)

    p = p.load(ExecutionContext(tmpdir))

    assert_callback_counts(p, expected_batches)