예제 #1
0
def test_random_split(create_random_dummy_data):
    """RandomSplit must reject malformed split maps, and must send every
    example to fold 0 when all of the weight sits on the "train" fold."""

    # A split map without a "train" entry is rejected.
    with pytest.raises(AssertionError):
        _ = RandomSplit(split_map={"test": 0.5,
                                   "eval": 0.5})

    # A split map with only a single fold is rejected.
    with pytest.raises(AssertionError):
        _ = RandomSplit(split_map={"train": 1.0})

    # A split map whose fractions are not all floats is rejected.
    with pytest.raises(AssertionError):
        _ = RandomSplit(split_map={"train": 0.5,
                                   "eval": "testest"})

    splitter = RandomSplit(split_map={"train": 1.0,
                                      "eval": 0.0})
    partition, partition_kwargs = splitter.partition_fn()

    # Defaults: neither a schema nor statistics are set on construction.
    assert not splitter.schema
    assert not splitter.statistics

    examples = create_random_dummy_data()

    num_splits = splitter.get_num_splits()
    folds = [partition(example, num_splits, **partition_kwargs)
             for example in examples]

    # With 100% of the weight on "train", the outcome is deterministic:
    # every example must land in fold 0.
    assert all(fold == 0 for fold in folds)
예제 #2
0
from zenml.exceptions import AlreadyExistsException

# Build the training pipeline.
training_pipeline = TrainingPipeline()

# Register the datasource; creating it tracks and versions it automatically.
# If it was registered on an earlier run, fetch the existing one instead.
try:
    datasource = CSVDatasource(name='Pima Indians Diabetes',
                               path='gs://zenml_quickstart/diabetes.csv')
except AlreadyExistsException:
    repo = Repository.get_instance()
    datasource = repo.get_datasource_by_name('Pima Indians Diabetes')
training_pipeline.add_datasource(datasource)

# 70/30 random train/eval split.
training_pipeline.add_split(
    RandomSplit(split_map={'train': 0.7, 'eval': 0.3}))

# Add a preprocessing unit
training_pipeline.add_preprocesser(
    StandardPreprocesser(features=[
        'times_pregnant', 'pgc', 'dbp', 'tst', 'insulin', 'bmi', 'pedigree',
        'age'
    ],
                         labels=['has_diabetes'],
                         overwrite={
                             'has_diabetes': {
                                 'transform': [{
                                     'method': 'no_transform',
                                     'parameters': {}
                                 }]
                             }
예제 #3
0
파일: run.py 프로젝트: sjoerdteunisse/zenml
nlp_pipeline = NLPPipeline()

# Reuse the datasource if it is already registered, create it otherwise.
try:
    datasource = CSVDatasource(name="My Urdu Text",
                               path="gs://zenml_quickstart/urdu_fake_news.csv")
except AlreadyExistsException:
    datasource = Repository.get_instance().get_datasource_by_name(
        name="My Urdu Text")

nlp_pipeline.add_datasource(datasource)

# Train a BERT word-piece tokenizer over the "news" text column.
nlp_pipeline.add_tokenizer(
    tokenizer_step=HuggingFaceTokenizerStep(text_feature="news",
                                            tokenizer="bert-wordpiece",
                                            vocab_size=3000))

# 90/10 random train/eval split.
nlp_pipeline.add_split(RandomSplit(split_map={"train": 0.9, "eval": 0.1}))

# Fine-tune a DistilBERT-based trainer on the split data.
nlp_pipeline.add_trainer(
    UrduTrainer(model_name="distilbert-base-uncased",
                epochs=3,
                batch_size=64,
                learning_rate=5e-3))

nlp_pipeline.run()

# evaluate the model with the sentence "The earth is flat"
# which should (ideally) return FAKE_NEWS
nlp_pipeline.predict_sentence("دنیا سیدھی ہے")
예제 #4
0
파일: run.py 프로젝트: Federicowengi/zenml
# Assemble the training pipeline.
training_pipeline = TrainingPipeline()

# Register the datasource (tracked and versioned on creation); if it
# already exists, look up the previously registered instance instead.
try:
    datasource = CSVDatasource(name='Pima Indians Diabetes',
                               path='gs://zenml_quickstart/diabetes.csv')
except AlreadyExistsException:
    datasource = Repository.get_instance().get_datasource_by_name(
        'Pima Indians Diabetes')
training_pipeline.add_datasource(datasource)

# 70/30 random train/eval split, executed on the processing backend.
split_step = RandomSplit(split_map={'train': 0.7, 'eval': 0.3})
training_pipeline.add_split(split_step.with_backend(processing_backend))

# Add a preprocessing unit
training_pipeline.add_preprocesser(
    StandardPreprocesser(features=[
        'times_pregnant', 'pgc', 'dbp', 'tst', 'insulin', 'bmi', 'pedigree',
        'age'
    ],
                         labels=['has_diabetes'],
                         overwrite={
                             'has_diabetes': {
                                 'transform': [{
                                     'method': 'no_transform',
                                     'parameters': {}
                                 }]