def test_run_training_cloud_1(self, mocker):
    """Cloud training with ``destroy=False`` must not tear the infrastructure down.

    Every AWS lifecycle method is mocked. The test checks that
    ``run_training_cloud`` calls ``init`` and ``apply`` once each and calls
    ``train`` with the job/image directories. It also checks that ``destroy``
    is never called and that the AWS object is built from the cloud arguments.
    """
    mocked_init = mocker.patch('imageatm.components.cloud.AWS.init')
    mocked_apply = mocker.patch('imageatm.components.cloud.AWS.apply')
    mocked_train = mocker.patch('imageatm.components.cloud.AWS.train')
    mocked_destroy = mocker.patch('imageatm.components.cloud.AWS.destroy')
    # A patched __init__ has to return None, otherwise constructing AWS raises
    # TypeError; the plain MagicMock default would return another mock.
    mocked_ctor = mocker.patch('imageatm.components.cloud.AWS.__init__', return_value=None)

    run_training_cloud(
        image_dir=TEST_IMAGE_DIR,
        job_dir=TEST_JOB_DIR,
        provider='aws',
        tf_dir=TEST_TF_DIR,
        region=TEST_REGION,
        instance_type=TEST_INSTANCE_TYPE,
        vpc_id=TEST_VPC_ID,
        bucket=TEST_S3_BUCKET,
        destroy=False,
        cloud_tag=TEST_CLOUD_TAG,
    )

    # Note the rename: the ``bucket`` argument reaches AWS as ``s3_bucket``.
    expected_ctor_kwargs = dict(
        tf_dir=TEST_TF_DIR,
        region=TEST_REGION,
        instance_type=TEST_INSTANCE_TYPE,
        vpc_id=TEST_VPC_ID,
        s3_bucket=TEST_S3_BUCKET,
        job_dir=TEST_JOB_DIR,
        cloud_tag=TEST_CLOUD_TAG,
    )

    mocked_init.assert_called_once()
    mocked_apply.assert_called_once()
    mocked_train.assert_called_with(job_dir=TEST_JOB_DIR, image_dir=TEST_IMAGE_DIR)
    mocked_destroy.assert_not_called()
    mocked_ctor.assert_called_with(**expected_ctor_kwargs)
def train(
    config,
    config_file: Optional[Path] = None,
    job_dir: Optional[Path] = None,
    image_dir: Optional[Path] = None,
    provider: Optional[str] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    vpc_id: Optional[str] = None,
    bucket: Optional[str] = None,
    tf_dir: Optional[Path] = None,
    train_cloud: Optional[bool] = None,
    destroy: Optional[bool] = None,
    batch_size: Optional[int] = None,
    learning_rate_dense: Optional[float] = None,
    learning_rate_all: Optional[float] = None,
    epochs_train_dense: Optional[int] = None,
    epochs_train_all: Optional[int] = None,
    base_model_name: Optional[str] = None,
    cloud_tag: Optional[str] = None,
):
    """Run only the training step, locally or in the cloud.

    All keyword arguments are forwarded to ``update_config`` as overrides for
    ``config``. The train component is then switched on with
    ``config.train['run'] = True`` and validated. If ``config.train['cloud']``
    is truthy, training goes through ``run_training_cloud``; otherwise it
    goes through ``run_training``.
    """
    config = update_config(
        config=config,
        config_file=config_file,
        job_dir=job_dir,
        image_dir=image_dir,
        provider=provider,
        instance_type=instance_type,
        region=region,
        vpc_id=vpc_id,
        bucket=bucket,
        tf_dir=tf_dir,
        train_cloud=train_cloud,
        destroy=destroy,
        batch_size=batch_size,
        learning_rate_dense=learning_rate_dense,
        learning_rate_all=learning_rate_all,
        epochs_train_dense=epochs_train_dense,
        epochs_train_all=epochs_train_all,
        base_model_name=base_model_name,
        cloud_tag=cloud_tag,
    )

    config.train['run'] = True
    validate_config(config, ['train'])

    if not config.train.get('cloud'):
        # Local training: only the train section is needed.
        from imageatm.scripts import run_training

        run_training(**config.train)
        return

    # Cloud training needs both sections; on a key clash the train value wins
    # because it is unpacked last.
    from imageatm.scripts import run_training_cloud

    cloud_and_train_kwargs = {**config.cloud, **config.train}
    run_training_cloud(**cloud_and_train_kwargs)
def test_run_training_cloud_2(self, mocker):
    """Cloud training with ``destroy=True`` forwards hyperparameters and tears down.

    The AWS lifecycle is mocked as in the ``destroy=False`` case. The test
    checks that every training hyperparameter passed to ``run_training_cloud``
    reaches ``AWS.train`` and that ``destroy`` is called exactly once.
    """
    mocked_init = mocker.patch('imageatm.components.cloud.AWS.init')
    mocked_apply = mocker.patch('imageatm.components.cloud.AWS.apply')
    mocked_train = mocker.patch('imageatm.components.cloud.AWS.train')
    mocked_destroy = mocker.patch('imageatm.components.cloud.AWS.destroy')
    # A patched __init__ has to return None, otherwise constructing AWS raises TypeError.
    mocked_ctor = mocker.patch('imageatm.components.cloud.AWS.__init__', return_value=None)

    # The same hyperparameters go in and are expected back on AWS.train.
    hyperparams = dict(
        epochs_train_dense=TEST_EPOCHS_TRAIN_DENSE,
        epochs_train_all=TEST_EPOCHS_TRAIN_ALL,
        learning_rate_dense=TEST_LEARNING_RATE_DENSE,
        learning_rate_all=TEST_LEARNING_RATE_ALL,
        batch_size=TEST_BATCH_SIZE,
        dropout_rate=TEST_DROPOUT_RATE,
        base_model_name=TEST_BASE_MODEL_NAME,
        loss=TEST_LOSS,
    )

    run_training_cloud(
        image_dir=TEST_IMAGE_DIR,
        job_dir=TEST_JOB_DIR,
        provider='aws',
        tf_dir=TEST_TF_DIR,
        region=TEST_REGION,
        instance_type=TEST_INSTANCE_TYPE,
        vpc_id=TEST_VPC_ID,
        bucket=TEST_S3_BUCKET,
        destroy=True,
        cloud_tag=TEST_CLOUD_TAG,
        **hyperparams,
    )

    mocked_init.assert_called_once()
    mocked_apply.assert_called_once()
    mocked_train.assert_called_with(job_dir=TEST_JOB_DIR, image_dir=TEST_IMAGE_DIR, **hyperparams)
    mocked_destroy.assert_called_once()
    # ``bucket`` is expected to reach the AWS constructor as ``s3_bucket``.
    mocked_ctor.assert_called_with(
        tf_dir=TEST_TF_DIR,
        region=TEST_REGION,
        instance_type=TEST_INSTANCE_TYPE,
        vpc_id=TEST_VPC_ID,
        s3_bucket=TEST_S3_BUCKET,
        job_dir=TEST_JOB_DIR,
        cloud_tag=TEST_CLOUD_TAG,
    )
def pipeline(
    config: Config,
    config_file: Path,
    job_dir: Optional[Path] = None,
    image_dir: Optional[Path] = None,
    samples_file: Optional[Path] = None,
    provider: Optional[str] = None,
    instance_type: Optional[str] = None,
    region: Optional[str] = None,
    vpc_id: Optional[str] = None,
    bucket: Optional[str] = None,
    tf_dir: Optional[Path] = None,
    train_cloud: Optional[bool] = None,
    destroy: Optional[bool] = None,
    resize: Optional[bool] = None,
    batch_size: Optional[int] = None,
    learning_rate_dense: Optional[float] = None,
    learning_rate_all: Optional[float] = None,
    epochs_train_dense: Optional[int] = None,
    epochs_train_all: Optional[int] = None,
    base_model_name: Optional[str] = None,
    cloud_tag: Optional[str] = None,
    create_report: Optional[bool] = None,
    kernel_name: Optional[str] = None,
    export_html: Optional[bool] = None,
    export_pdf: Optional[bool] = None,
):
    """Runs the entire pipeline based on config file.

    The keyword arguments are forwarded to ``update_config`` as overrides.
    The config is validated against the stages in ``config.pipeline``, and
    the job directory is created. Each stage that is listed then runs in
    this fixed order: ``dataprep``, ``train``, ``evaluate``, ``cloud``.
    """
    config = update_config(
        config=config,
        config_file=config_file,
        job_dir=job_dir,
        image_dir=image_dir,
        samples_file=samples_file,
        provider=provider,
        instance_type=instance_type,
        region=region,
        vpc_id=vpc_id,
        bucket=bucket,
        tf_dir=tf_dir,
        train_cloud=train_cloud,
        destroy=destroy,
        resize=resize,
        batch_size=batch_size,
        learning_rate_dense=learning_rate_dense,
        learning_rate_all=learning_rate_all,
        epochs_train_dense=epochs_train_dense,
        epochs_train_all=epochs_train_all,
        base_model_name=base_model_name,
        cloud_tag=cloud_tag,
        create_report=create_report,
        kernel_name=kernel_name,
        export_html=export_html,
        export_pdf=export_pdf,
    )

    validate_config(config, config.pipeline)

    # The logger is bound to job_dir, so the directory must exist first.
    job_path = Path(config.job_dir).resolve()
    job_path.mkdir(parents=True, exist_ok=True)
    logger = get_logger(__name__, config.job_dir)  # type: ignore

    if 'dataprep' in config.pipeline:
        from imageatm.scripts import run_dataprep

        logger.info('\n********************************\n'
                    '******* Data preparation *******\n'
                    '********************************')
        prepared = run_dataprep(**config.dataprep)

        # After resizing, later stages must read images from the image_dir
        # reported by the dataprep run; the component configs are rebuilt
        # from the updated top-level config.
        if config.dataprep.get('resize', False):
            config.image_dir = prepared.image_dir  # type: ignore
            config = update_component_configs(config)

    if 'train' in config.pipeline:
        logger.info('\n********************************\n'
                    '*********** Training ***********\n'
                    '********************************')
        if not config.train.get('cloud'):
            from imageatm.scripts import run_training

            run_training(**config.train)
        else:
            from imageatm.scripts import run_training_cloud

            # Train settings override cloud settings on a key clash (unpacked last).
            run_training_cloud(**{**config.cloud, **config.train})

    if 'evaluate' in config.pipeline:
        from imageatm.scripts import run_evaluation

        logger.info('\n********************************\n'
                    '********** Evaluation **********\n'
                    '********************************')
        run_evaluation(**config.evaluate)

    if 'cloud' in config.pipeline:
        from imageatm.scripts import run_cloud

        run_cloud(**config.cloud)