# Example 1
def test_random_split(create_random_dummy_data):
    """RandomSplit validates its split map and partitions data accordingly.

    Invalid maps (no "train" key, fewer than two folds, non-float ratios)
    must raise ``AssertionError``; a valid train-only map must route every
    example to fold 0.
    """
    # A split map without a "train" entry is rejected.
    with pytest.raises(AssertionError):
        _ = RandomSplit(split_map={"test": 0.5, "eval": 0.5})

    # A single-fold split map is rejected.
    with pytest.raises(AssertionError):
        _ = RandomSplit(split_map={"train": 1.0})

    # Non-float fold ratios are rejected.
    with pytest.raises(AssertionError):
        _ = RandomSplit(split_map={"train": 0.5, "eval": "testest"})

    # A degenerate-but-valid map: all weight on "train", none on "eval".
    splitter = RandomSplit(split_map={"train": 1.0, "eval": 0.0})
    partition_func, partition_kwargs = splitter.partition_fn()

    # Schema and statistics are unset (falsy) by default.
    assert not splitter.schema
    assert not splitter.statistics

    examples = create_random_dummy_data()

    fold_indices = [
        partition_func(example, splitter.get_num_splits(), **partition_kwargs)
        for example in examples
    ]

    # With eval weight 0.0 the split is deterministic: every example must
    # land in fold 0 (train); everything else about RandomSplit is random.
    assert all(index == 0 for index in fold_indices)
# Example 2
# Build and run an NLP pipeline that classifies Urdu fake news.
pipeline = NLPPipeline()

# EAFP: creating a datasource with an existing name raises, in which case
# the already-registered one is fetched from the repository instead.
try:
    datasource = CSVDatasource(name="my_text",
                               path="gs://zenml_quickstart/urdu_fake_news.csv")
except AlreadyExistsException:
    datasource = Repository.get_instance().get_datasource_by_name(
        name="my_text")

pipeline.add_datasource(datasource)

# Tokenize the "news" text column with a 3000-token WordPiece vocabulary.
pipeline.add_tokenizer(tokenizer_step=HuggingFaceTokenizerStep(
    text_feature="news",
    tokenizer="bert-wordpiece",
    vocab_size=3000))

# 90/10 train/eval split.
pipeline.add_split(RandomSplit(split_map={"train": 0.9, "eval": 0.1}))

# Fine-tune a DistilBERT-based model.
pipeline.add_trainer(UrduTrainer(model_name="distilbert-base-uncased",
                                 epochs=3,
                                 batch_size=64,
                                 learning_rate=5e-3))

pipeline.run()

# Evaluate the model with the (Urdu) sentence "The earth is flat",
# which should (ideally) return FAKE_NEWS.
pipeline.predict_sentence("دنیا سیدھی ہے")
# Example 3 — file: run.py, project: vingovan/zenml
# --- GCP / Cloud SQL configuration --------------------------------------
# NOTE(review): these are placeholder credentials; replace with real values
# (ideally from env vars or a secret manager) before running.
project = 'PROJECT'  # the project to launch the VM in
cloudsql_connection_name = f'{project}:REGION:INSTANCE'
mysql_db = 'DATABASE'
mysql_user = '******'
mysql_pw = 'PASSWORD'
# Assumes `artifact_store_path` is defined earlier in the file — TODO confirm.
training_job_dir = artifact_store_path + '/gcaiptrainer/'

training_pipeline = TrainingPipeline(name='GCP Orchestrated')

# Add a datasource. This will automatically track and version it.
ds = CSVDatasource(name='Pima Indians Diabetes',
                   path='gs://zenml_quickstart/diabetes.csv')
training_pipeline.add_datasource(ds)

# Add a split: 70/30 train/eval.
training_pipeline.add_split(RandomSplit(
    split_map={'train': 0.7, 'eval': 0.3}))

# Add a preprocessing unit. The numeric features are transformed with the
# standard defaults; the label column is passed through unchanged via the
# 'no_transform' override.
training_pipeline.add_preprocesser(
    StandardPreprocesser(
        features=['times_pregnant', 'pgc', 'dbp', 'tst', 'insulin', 'bmi',
                  'pedigree', 'age'],
        labels=['has_diabetes'],
        overwrite={'has_diabetes': {
            'transform': [{'method': 'no_transform', 'parameters': {}}]}}
    ))

# Add a trainer
training_pipeline.add_trainer(FeedForwardTrainer(
    loss='binary_crossentropy',
    last_activation='sigmoid',
# Example 4 — file: run.py, project: zilongqiu/zenml
# Define the training pipeline
training_pipeline = TrainingPipeline()

# Add a datasource. This will automatically track and version it.
# EAFP: creating a datasource with an existing name raises, in which case
# the already-registered datasource is fetched from the repository instead.
try:
    ds = CSVDatasource(name='Pima Indians Diabetes',
                       path='gs://zenml_quickstart/diabetes.csv')
except AlreadyExistsException:
    ds = Repository.get_instance().get_datasource_by_name(
        'Pima Indians Diabetes')
training_pipeline.add_datasource(ds)

# Add a split: 70/30 train/eval, executed on `processing_backend`
# (presumably defined earlier in the file — TODO confirm).
training_pipeline.add_split(
    RandomSplit(split_map={'train': 0.7, 'eval': 0.3}).with_backend(
        processing_backend)
)

# Add a preprocessing unit. The numeric features are transformed with the
# standard defaults; the label column is passed through unchanged via the
# 'no_transform' override. Runs on the same processing backend.
training_pipeline.add_preprocesser(
    StandardPreprocesser(
        features=['times_pregnant', 'pgc', 'dbp', 'tst', 'insulin', 'bmi',
                  'pedigree', 'age'],
        labels=['has_diabetes'],
        overwrite={'has_diabetes': {
            'transform': [{'method': 'no_transform', 'parameters': {}}]}}
    ).with_backend(processing_backend)
)

# Add a trainer
training_pipeline.add_trainer(FeedForwardTrainer(