コード例 #1
0
from zenml.datasources import CSVDatasource
from zenml.pipelines import TrainingPipeline
from zenml.repo import Repository
from zenml.steps.evaluator import TFMAEvaluator
from zenml.steps.preprocesser import StandardPreprocesser
from zenml.steps.split import RandomSplit
from zenml.steps.trainer import TFFeedForwardTrainer
from zenml.exceptions import AlreadyExistsException

# Define the training pipeline
training_pipeline = TrainingPipeline()

# Add a datasource. This will automatically track and version it.
# EAFP: creating a datasource whose name is already registered raises
# AlreadyExistsException; in that case reuse the existing registration
# instead of creating a duplicate.
try:
    ds = CSVDatasource(name='Pima Indians Diabetes',
                       path='gs://zenml_quickstart/diabetes.csv')
except AlreadyExistsException:
    # Look up the previously registered datasource by the same name.
    ds = Repository.get_instance().get_datasource_by_name(
        'Pima Indians Diabetes')
training_pipeline.add_datasource(ds)

# Add a split: 70% of rows go to 'train', 30% to 'eval', chosen at random.
training_pipeline.add_split(RandomSplit(split_map={'train': 0.7, 'eval': 0.3}))

# Add a preprocessing unit
training_pipeline.add_preprocesser(
    StandardPreprocesser(features=[
        'times_pregnant', 'pgc', 'dbp', 'tst', 'insulin', 'bmi', 'pedigree',
        'age'
    ],
                         labels=['has_diabetes'],
コード例 #2
0
# Locations of the CSV test fixtures used by this test run.
# NOTE(review): TEST_ROOT and pipeline_root are defined outside this
# fragment — presumably test-suite constants; confirm against the caller.
csv_root = os.path.join(TEST_ROOT, "test_data")
image_root = os.path.join(csv_root, "images")


repo: Repository = Repository.get_instance()
# Start from a clean pipelines directory so artifacts from earlier runs
# cannot interfere with this one.
if path_utils.is_dir(pipeline_root):
    path_utils.rm_dir(pipeline_root)
repo.zenml_config.set_pipelines_dir(pipeline_root)

try:
    for i in range(1, 6):
        training_pipeline = TrainingPipeline(name='csvtest{0}'.format(i))

        try:
            # Add a datasource. This will automatically track and version it.
            ds = CSVDatasource(name='my_csv_datasource',
                               path=os.path.join(csv_root, "my_dataframe.csv"))
        except AlreadyExistsException:
            ds = repo.get_datasource_by_name("my_csv_datasource")

        training_pipeline.add_datasource(ds)

        # Add a split
        training_pipeline.add_split(CategoricalDomainSplit(
            categorical_column="name",
            split_map={'train': ["arnold", "nicholas"], 'eval': ["lülük"]}))

        # Add a preprocessing unit
        training_pipeline.add_preprocesser(
            StandardPreprocesser(
                features=["name", "age"],
                labels=['gpa'],
コード例 #3
0
ファイル: run.py プロジェクト: sjoerdteunisse/zenml
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
#  or implied. See the License for the specific language governing
#  permissions and limitations under the License.

from examples.nlp.training.trainer import UrduTrainer
from zenml.datasources import CSVDatasource
from zenml.exceptions import AlreadyExistsException
from zenml.pipelines import NLPPipeline
from zenml.repo import Repository
from zenml.steps.split import RandomSplit
from zenml.steps.tokenizer import HuggingFaceTokenizerStep

# Assemble an NLP pipeline: datasource -> tokenizer -> train/eval split.
nlp_pipeline = NLPPipeline()

# Register the CSV datasource; if one with this name already exists,
# reuse it instead of creating a duplicate (EAFP pattern).
try:
    ds = CSVDatasource(name="My Urdu Text",
                       path="gs://zenml_quickstart/urdu_fake_news.csv")
except AlreadyExistsException:
    ds = Repository.get_instance().get_datasource_by_name(name="My Urdu Text")

nlp_pipeline.add_datasource(ds)

# Tokenize the "news" text column with a WordPiece tokenizer whose
# vocabulary is capped at 3000 tokens.
tokenizer_step = HuggingFaceTokenizerStep(text_feature="news",
                                          tokenizer="bert-wordpiece",
                                          vocab_size=3000)

nlp_pipeline.add_tokenizer(tokenizer_step=tokenizer_step)

# Random 90/10 train/eval split.
nlp_pipeline.add_split(RandomSplit(split_map={"train": 0.9, "eval": 0.1}))

nlp_pipeline.add_trainer(
    UrduTrainer(model_name="distilbert-base-uncased",
コード例 #4
0
ファイル: run.py プロジェクト: sjoerdteunisse/zenml
# Read MySQL connection settings from the environment; the port falls
# back to the MySQL default (3306) when unset.
MYSQL_PWD = os.getenv('MYSQL_PWD')
MYSQL_HOST = os.getenv('MYSQL_HOST')
MYSQL_PORT = os.getenv('MYSQL_PORT', '3306')

# Fail fast with a clear message when required settings are missing.
# Plain `assert` statements (as used before) are silently stripped when
# Python runs with -O, so validate explicitly and raise instead.
_required_mysql_settings = {
    'MYSQL_DB': MYSQL_DB,
    'MYSQL_USER': MYSQL_USER,
    'MYSQL_PWD': MYSQL_PWD,
    'MYSQL_HOST': MYSQL_HOST,
    'MYSQL_PORT': MYSQL_PORT,
}
_missing = [name for name, value in _required_mysql_settings.items()
            if not value]
if _missing:
    raise RuntimeError(
        'Missing required MySQL environment variable(s): '
        + ', '.join(_missing))

# Define the training pipeline
training_pipeline = TrainingPipeline()

# Add a datasource. This will automatically track and version it.
try:
    ds = CSVDatasource(name='Pima Indians Diabetes AWS',
                       path='s3://zenml-quickstart/diabetes.csv')
except AlreadyExistsException:
    # BUG FIX: fetch the datasource under the same name it was created
    # with ('Pima Indians Diabetes AWS'). The previous lookup used
    # 'Pima Indians Diabetes', which would return a different datasource
    # (or fail) whenever the AWS-named one already existed.
    ds = Repository.get_instance().get_datasource_by_name(
        'Pima Indians Diabetes AWS')
training_pipeline.add_datasource(ds)

# Add a split: 60% of rows go to 'train', 40% to 'eval', chosen at random.
training_pipeline.add_split(RandomSplit(split_map={'train': 0.6, 'eval': 0.4}))

# Add a preprocessing unit
training_pipeline.add_preprocesser(
    StandardPreprocesser(features=[
        'times_pregnant', 'pgc', 'dbp', 'tst', 'insulin', 'bmi', 'pedigree',
        'age'
    ],
                         labels=['has_diabetes'],