コード例 #1
0
ファイル: run.py プロジェクト: vingovan/zenml
from zenml.core.steps.trainer.feedforward_trainer import FeedForwardTrainer

# Connection and storage settings for the GCP-orchestrated run.
# NOTE(review): the values below look like placeholders -- confirm and
# replace them before running.
artifact_store_path = 'gs://your-bucket-name/optional-subfolder'
project = 'PROJECT'  # the project to launch the VM in
cloudsql_connection_name = f'{project}:REGION:INSTANCE'
mysql_db = 'DATABASE'
mysql_user = '******'
mysql_pw = 'PASSWORD'
# Training job directory lives under the artifact store path.
training_job_dir = f'{artifact_store_path}/gcaiptrainer/'

training_pipeline = TrainingPipeline(name='GCP Orchestrated')

# Add a datasource. This will automatically track and version it.
ds = CSVDatasource(
    name='Pima Indians Diabetes',
    path='gs://zenml_quickstart/diabetes.csv',
)
training_pipeline.add_datasource(ds)

# Add a split: 70% of the data goes to 'train', 30% to 'eval'.
_split_step = RandomSplit(split_map={'train': 0.7, 'eval': 0.3})
training_pipeline.add_split(_split_step)

# Add a preprocessing unit.
_feature_columns = [
    'times_pregnant', 'pgc', 'dbp', 'tst', 'insulin', 'bmi',
    'pedigree', 'age',
]
# Override the transform list of the label column with the single
# 'no_transform' method (no parameters).
_label_overwrite = {
    'has_diabetes': {
        'transform': [{'method': 'no_transform', 'parameters': {}}],
    },
}
_preprocess_step = StandardPreprocesser(
    features=_feature_columns,
    labels=['has_diabetes'],
    overwrite=_label_overwrite,
)
training_pipeline.add_preprocesser(_preprocess_step)
コード例 #2
0
from examples.gan.gan_functions import CycleGANTrainer
from examples.gan.preprocessing import GANPreprocessor

# Repository handle; used below to look up a datasource by name.
repo: Repository = Repository().get_instance()

# Pipeline for the CycleGAN example, created with caching disabled.
gan_pipeline = TrainingPipeline(name="whynotletitfly", enable_cache=False)

# Try to create the image datasource. If construction fails, fall back to
# fetching the datasource registered under the same name.
# NOTE(review): presumably the failure of interest is "datasource already
# exists" -- confirm, and narrow the except clause to that exception type.
# NOTE(review): base_path is a hard-coded absolute path on one developer's
# machine; it must be adjusted to run elsewhere.
try:
    ds = ImageDatasource(
        name="gan_images",
        base_path="/Users/nicholasjunge/workspaces/maiot/ce_project/images_mini"
    )
except Exception:
    # `except Exception` (rather than a bare `except:`) keeps the fallback
    # but lets KeyboardInterrupt / SystemExit propagate.
    ds = repo.get_datasource_by_name('gan_images')

gan_pipeline.add_datasource(ds)

# Split on the "label" column: label value 0 goes to 'train', 1 to 'eval'.
gan_pipeline.add_split(
    CategoricalDomainSplit(categorical_column="label",
                           split_map={
                               "train": [0],
                               "eval": [1]
                           }))

gan_pipeline.add_preprocesser(GANPreprocessor())

gan_pipeline.add_trainer(CycleGANTrainer(epochs=5))

gan_pipeline.run()