Code example #1
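These excerpts come from imageatm's commands.py and rely on module-level imports that are not shown here: Path from pathlib, Optional from typing, and the package's own Config, update_config and validate_config helpers.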
def evaluate(
    config: Config,
    config_file: Optional[Path] = None,
    image_dir: Optional[Path] = None,
    job_dir: Optional[Path] = None,
    create_report: Optional[bool] = None,
    kernel_name: Optional[str] = None,
    export_html: Optional[bool] = None,
    export_pdf: Optional[bool] = None,
):
    config = update_config(
        config=config,
        config_file=config_file,
        job_dir=job_dir,
        image_dir=image_dir,
        create_report=create_report,
        kernel_name=kernel_name,
        export_html=export_html,
        export_pdf=export_pdf,
    )

    config.evaluate['run'] = True
    validate_config(config, ['evaluate'])

    from imageatm.scripts import run_evaluation

    run_evaluation(**config.evaluate)
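A minimal usage sketch for evaluate, with hypothetical paths; the Config object is built the same way as in the test of code example #7, and update_config then merges the keyword overrides into config.evaluate before run_evaluation is called:

from pathlib import Path

config = Config()  # empty defaults; update_config fills in the rest

evaluate(
    config,
    image_dir=Path('data/images'),   # hypothetical directories
    job_dir=Path('jobs/evaluation'),
    create_report=True,              # optional report generation (see kernel_name / export_html / export_pdf)
)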
Code example #2
def train(
    config,
    config_file: Optional[Path] = None,
    job_dir: Optional[Path] = None,
    image_dir: Optional[Path] = None,
    provider: Optional[str] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    vpc_id: Optional[str] = None,
    bucket: Optional[str] = None,
    tf_dir: Optional[Path] = None,
    train_cloud: Optional[bool] = None,
    destroy: Optional[bool] = None,
    batch_size: Optional[int] = None,
    learning_rate_dense: Optional[float] = None,
    learning_rate_all: Optional[float] = None,
    epochs_train_dense: Optional[int] = None,
    epochs_train_all: Optional[int] = None,
    base_model_name: Optional[str] = None,
    cloud_tag: Optional[str] = None,
):
    config = update_config(
        config=config,
        config_file=config_file,
        job_dir=job_dir,
        image_dir=image_dir,
        provider=provider,
        instance_type=instance_type,
        region=region,
        vpc_id=vpc_id,
        bucket=bucket,
        tf_dir=tf_dir,
        train_cloud=train_cloud,
        destroy=destroy,
        batch_size=batch_size,
        learning_rate_dense=learning_rate_dense,
        learning_rate_all=learning_rate_all,
        epochs_train_dense=epochs_train_dense,
        epochs_train_all=epochs_train_all,
        base_model_name=base_model_name,
        cloud_tag=cloud_tag,
    )

    config.train['run'] = True

    validate_config(config, ['train'])

    if config.train.get('cloud'):
        from imageatm.scripts import run_training_cloud

        run_training_cloud(**{**config.cloud, **config.train})
    else:
        from imageatm.scripts import run_training

        run_training(**config.train)
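A sketch of both dispatch paths of train, with placeholder values; the AWS-style settings mirror the test fixture in code example #7, and it is assumed that the train_cloud option ends up as config.train['cloud'] inside update_config:

from pathlib import Path

config = Config()

# Local training: config.train['cloud'] keeps its default of False
train(
    config,
    image_dir=Path('data/images'),
    job_dir=Path('jobs/train'),
    batch_size=16,
    epochs_train_dense=2,
)

# Cloud training: train_cloud=True presumably sets config.train['cloud'] = True,
# so the merged cloud + train settings are forwarded to run_training_cloud
train(
    config,
    image_dir=Path('data/images'),
    job_dir=Path('jobs/train'),
    train_cloud=True,
    provider='aws',
    region='eu-west-1',
    instance_type='p2.xlarge',
    vpc_id='vpc-123abc',          # placeholder
    bucket='s3://my-bucket',      # placeholder; the bucket must already exist
    cloud_tag='my_user',
)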
Code example #3
File: commands.py   Project: wangshoujun20/imageatm
def evaluate(
    config: Config,
    config_file: Optional[Path] = None,
    image_dir: Optional[Path] = None,
    job_dir: Optional[Path] = None,
):
    config = update_config(config=config,
                           config_file=config_file,
                           job_dir=job_dir,
                           image_dir=image_dir)

    config.evaluate['run'] = True
    validate_config(config, ['evaluate'])

    from imageatm.scripts import run_evaluation

    run_evaluation(**config.evaluate)
Code example #4
def dataprep(
    config: Config,
    config_file: Optional[Path] = None,
    image_dir: Optional[Path] = None,
    samples_file: Optional[Path] = None,
    job_dir: Optional[Path] = None,
    resize: Optional[bool] = None,
):
    config = update_config(
        config=config,
        config_file=config_file,
        job_dir=job_dir,
        image_dir=image_dir,
        samples_file=samples_file,
        resize=resize,
    )

    config.dataprep['run'] = True
    validate_config(config, ['dataprep'])

    from imageatm.scripts import run_dataprep

    run_dataprep(**config.dataprep)
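The same calling pattern works for dataprep; samples_file and resize are the parameters specific to this step. A sketch with placeholder paths:

from pathlib import Path

config = Config()

dataprep(
    config,
    image_dir=Path('data/images'),
    samples_file=Path('data/samples.json'),  # e.g. a samples.json as in the test config of example #7
    resize=True,                             # resize images; the pipeline example shows image_dir is updated afterwards
)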
Code example #5
def cloud(
    config,
    job_dir: Optional[Path] = None,
    config_file: Optional[Path] = None,
    provider: Optional[str] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    vpc_id: Optional[str] = None,
    bucket: Optional[str] = None,
    tf_dir: Optional[Path] = None,
    train_cloud: Optional[bool] = None,
    destroy: Optional[bool] = None,
    no_destroy: Optional[bool] = None,
    cloud_tag: Optional[str] = None,
):
    config = update_config(
        config=config,
        job_dir=job_dir,
        config_file=config_file,
        provider=provider,
        instance_type=instance_type,
        region=region,
        vpc_id=vpc_id,
        bucket=bucket,
        tf_dir=tf_dir,
        train_cloud=train_cloud,
        destroy=destroy,
        no_destroy=no_destroy,
        cloud_tag=cloud_tag,
    )

    config.cloud['run'] = True
    validate_config(config, ['cloud'])

    from imageatm.scripts import run_cloud

    run_cloud(**config.cloud)
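Because the command exposes both destroy and no_destroy flags, it can presumably also be used purely to tear down the Terraform-managed resources; a sketch, assuming the provider/region/bucket settings live in a config file as in code example #7:

from pathlib import Path

config = Config()

# Tear down previously provisioned infrastructure
cloud(
    config,
    config_file=Path('config/config_train.yml'),  # placeholder path
    destroy=True,
)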
Code example #6
def pipeline(
    config: Config,
    config_file: Path,
    job_dir: Optional[Path] = None,
    image_dir: Optional[Path] = None,
    samples_file: Optional[Path] = None,
    provider: Optional[str] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    vpc_id: Optional[str] = None,
    bucket: Optional[str] = None,
    tf_dir: Optional[Path] = None,
    train_cloud: Optional[bool] = None,
    destroy: Optional[bool] = None,
    resize: Optional[bool] = None,
    batch_size: Optional[int] = None,
    learning_rate_dense: Optional[float] = None,
    learning_rate_all: Optional[float] = None,
    epochs_train_dense: Optional[int] = None,
    epochs_train_all: Optional[int] = None,
    base_model_name: Optional[str] = None,
    cloud_tag: Optional[str] = None,
    create_report: Optional[bool] = None,
    kernel_name: Optional[str] = None,
    export_html: Optional[bool] = None,
    export_pdf: Optional[bool] = None,
):
    """Runs the entire pipeline based on config file."""
    config = update_config(
        config=config,
        config_file=config_file,
        job_dir=job_dir,
        image_dir=image_dir,
        samples_file=samples_file,
        provider=provider,
        instance_type=instance_type,
        region=region,
        vpc_id=vpc_id,
        bucket=bucket,
        tf_dir=tf_dir,
        train_cloud=train_cloud,
        destroy=destroy,
        resize=resize,
        batch_size=batch_size,
        learning_rate_dense=learning_rate_dense,
        learning_rate_all=learning_rate_all,
        epochs_train_dense=epochs_train_dense,
        epochs_train_all=epochs_train_all,
        base_model_name=base_model_name,
        cloud_tag=cloud_tag,
        create_report=create_report,
        kernel_name=kernel_name,
        export_html=export_html,
        export_pdf=export_pdf,
    )

    validate_config(config, config.pipeline)

    Path(config.job_dir).resolve().mkdir(parents=True, exist_ok=True)

    logger = get_logger(__name__, config.job_dir)  # type: ignore

    if 'dataprep' in config.pipeline:
        from imageatm.scripts import run_dataprep

        logger.info('\n********************************\n'
                    '******* Data preparation *******\n'
                    '********************************')

        dp = run_dataprep(**config.dataprep)

        # update image_dir if images were resized
        if config.dataprep.get('resize', False):
            config.image_dir = dp.image_dir  # type: ignore
            config = update_component_configs(config)

    if 'train' in config.pipeline:
        logger.info('\n********************************\n'
                    '*********** Training ***********\n'
                    '********************************')

        if config.train.get('cloud'):
            from imageatm.scripts import run_training_cloud

            run_training_cloud(**{**config.cloud, **config.train})
        else:
            from imageatm.scripts import run_training

            run_training(**config.train)

    if 'evaluate' in config.pipeline:
        from imageatm.scripts import run_evaluation

        logger.info('\n********************************\n'
                    '********** Evaluation **********\n'
                    '********************************')

        run_evaluation(**config.evaluate)

    if 'cloud' in config.pipeline:
        from imageatm.scripts import run_cloud

        run_cloud(**config.cloud)
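pipeline is the only command whose config_file argument is required; which stages actually run is read from config.pipeline after the file is loaded. A minimal invocation sketch with a placeholder path:

from pathlib import Path

config = Config()

# Everything else (image_dir, batch_size, cloud settings, ...) can either live
# in the config file or be passed as keyword overrides, which take precedence
pipeline(
    config,
    config_file=Path('config/config_train.yml'),  # placeholder path
)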
Code example #7
def test_update_config():
    # check that defaults are being set
    config = Config()

    result = update_config(config)

    assert result.train == {'cloud': False}
    assert result.data_prep == {'resize': False}
    assert result.cloud == {}
    assert result.evaluate == {}

    # check that defaults, image_dir, and job_dir are being set
    config = Config()
    config.image_dir = 'test_image'
    config.job_dir = 'test_job'

    result = update_config(config)

    assert result.train == {
        'cloud': False,
        'image_dir': 'test_image',
        'job_dir': 'test_job'
    }
    assert result.data_prep == {
        'resize': False,
        'image_dir': 'test_image',
        'job_dir': 'test_job'
    }
    assert result.cloud == {'job_dir': 'test_job'}
    assert result.evaluate == {
        'image_dir': 'test_image',
        'job_dir': 'test_job'
    }

    # check that config file gets populated correctly
    # p is presumably this test module's path, e.g. p = Path(__file__) defined at module scope
    TEST_CONFIG_FILE = p.resolve().parent / 'test_configs' / 'config_train.yml'

    config = Config()

    result = update_config(config, config_file=TEST_CONFIG_FILE)

    assert result.train == {
        'run': True,
        'cloud': False,
        'image_dir': 'test_train/images',
        'job_dir': 'test_train/job_dir',
    }
    assert result.data_prep == {
        'run': False,
        'resize': True,
        'image_dir': 'test_data_prep/images',
        'job_dir': 'test_data_prep/job_dir',
        'samples_file': 'test_data_prep/samples.json',
    }
    assert result.cloud == {
        'run': False,
        'provider': 'aws',  # supported providers ['aws']
        'tf_dir': 'cloud/aws',
        'region': 'eu-west-1',  # supported regions ['eu-west-1', 'eu-central-1']
        'vpc_id': 'abc',
        'instance_type': 't2.micro',  # supported instances ['p2.xlarge']
        'bucket': 's3://test_bucket',  # s3 bucket needs to exist, will not be created/destroyed by terraform
        'destroy': True,
        'cloud_tag': 'test_user',
    }
    assert result.evaluate == {
        'run': False,
        'image_dir': 'test_evaluate/images',
        'job_dir': 'test_evaluate/job_dir',
    }

    # check that config file gets populated correctly and image and job dir are updated
    TEST_CONFIG_FILE = p.resolve().parent / 'test_configs' / 'config_train.yml'

    config = Config()
    config.image_dir = 'test_image'
    config.job_dir = 'test_job'

    result = update_config(config, config_file=TEST_CONFIG_FILE)

    print(result.cloud)

    assert result.train == {
        'run': True,
        'cloud': False,
        'image_dir': 'test_image',
        'job_dir': 'test_job',
    }
    assert result.data_prep == {
        'run': False,
        'resize': True,
        'image_dir': 'test_image',
        'job_dir': 'test_job',
        'samples_file': 'test_data_prep/samples.json',
    }
    assert result.cloud == {
        'run': False,
        'provider': 'aws',
        'tf_dir': 'cloud/aws',
        'region': 'eu-west-1',
        'vpc_id': 'abc',
        'instance_type': 't2.micro',
        'bucket': 's3://test_bucket',
        'destroy': True,
        'job_dir': 'test_job',
        'cloud_tag': 'test_user',
    }
    assert result.evaluate == {
        'run': False,
        'image_dir': 'test_image',
        'job_dir': 'test_job'
    }

    # test that options overwrite config file
    TEST_CONFIG_FILE = p.resolve().parent / 'test_configs' / 'config_train.yml'

    config = Config()

    result = update_config(
        config,
        config_file=TEST_CONFIG_FILE,
        image_dir='test_image',
        job_dir='test_job',
        region='eu-central-1',
    )

    assert result.train == {
        'run': True,
        'cloud': False,
        'image_dir': 'test_image',
        'job_dir': 'test_job',
    }

    assert result.data_prep == {
        'run': False,
        'resize': True,
        'image_dir': 'test_image',
        'job_dir': 'test_job',
        'samples_file': 'test_data_prep/samples.json',
    }

    assert result.cloud == {
        'run': False,
        'provider': 'aws',
        'tf_dir': 'cloud/aws',
        'region': 'eu-central-1',
        'vpc_id': 'abc',
        'instance_type': 't2.micro',
        'bucket': 's3://test_bucket',
        'destroy': True,
        'job_dir': 'test_job',
        'cloud_tag': 'test_user',
    }

    assert result.evaluate == {
        'run': False,
        'image_dir': 'test_image',
        'job_dir': 'test_job'
    }
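Taken together, the assertions suggest the precedence update_config applies: built-in defaults are overridden by values from the config file, while both attributes already set on the Config object (image_dir, job_dir) and explicit keyword options override the file (the last block shows region='eu-central-1' replacing the file's 'eu-west-1').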