def test_check_field_exists():
    """Verify check_field_exists for list input, dict input and bad types."""
    # List input: reports fields not in the list, empty result when all exist.
    list_data = ['a', 'b', 'c']
    assert 'd' in check_field_exists(list_data, ['a', 'd'])
    assert not check_field_exists(list_data, ['a', 'c'])

    # Dict input: missing keys are reported the same way.
    dict_data = {'a': 1, 'b': 2, 'c': 3}
    assert 'd' in check_field_exists(dict_data, ['a', 'd'])

    # Anything that is neither a list nor a dict must raise ValueError.
    with pytest.raises(ValueError):
        check_field_exists(111, ['a'])
def run_training_job():
    """POST call for initiating retraining of models.

    Expects a JSON dict containing data_version, bucket_name, github_repo
    and ecosystem; dispatches the job to the EMR runner registered for the
    ecosystem.

    :raises HTTPError: 400 on a missing/invalid payload, missing fields,
        or an unsupported ecosystem.
    :return: tuple of (JSON status payload, 200)
    """
    required_fields = [
        "data_version", "bucket_name", "github_repo", "ecosystem"
    ]
    input_data = request.get_json()
    # Validate the payload shape BEFORE field checks: check_field_exists
    # raises a bare ValueError for non list/dict input (e.g. a None body),
    # which would escape as a 500 instead of the intended HTTP 400.
    if not input_data:
        raise HTTPError(400, "Expected JSON request")
    if not isinstance(input_data, dict):
        raise HTTPError(400, "Expected dict of input parameters")
    missing_fields = check_field_exists(input_data, required_fields)
    if missing_fields:
        raise HTTPError(400, "These field(s) {} are missing from input "
                             "data".format(missing_fields))
    input_data['environment'] = config.DEPLOYMENT_PREFIX
    ecosystem = input_data.get('ecosystem')
    emr_instance = emr_instances.get(ecosystem)
    if emr_instance is None:
        raise HTTPError(400, "Ecosystem {} not supported yet.".format(ecosystem))
    # Registry holds classes; instantiate, then submit the job.
    emr_instance = emr_instance()
    status = emr_instance.run_job(input_data)
    return jsonify(status), 200
def construct_job(self, input_dict):
    """Submit emr job.

    Validates the input dict, resolves AWS/GitHub credentials (environment
    variables take precedence over payload values), serialises hyper
    parameters, and establishes the EMR connection.

    :param input_dict: dict with environment, data_version, bucket_name,
        github_repo; optionally hyper_params and credential overrides.
    :raises ValueError: on missing fields, an unreachable github_repo,
        or a failed EMR connection.
    """
    required_fields = [
        'environment', 'data_version', 'bucket_name', 'github_repo'
    ]
    missing_fields = check_field_exists(input_dict, required_fields)
    if missing_fields:
        logger.error("Missing the parameters in input_dict",
                     extra={"missing_fields": missing_fields})
        raise ValueError(
            "Required fields are missing in the input {}".format(
                missing_fields))
    self.env = input_dict.get('environment')
    self.data_version = input_dict.get('data_version')
    github_repo = input_dict.get('github_repo')
    if not check_url_alive(github_repo):
        logger.error(
            "Unable to find the github_repo {}".format(github_repo))
        raise ValueError(
            "Unable to find the github_repo {}".format(github_repo))
    self.training_repo_url = github_repo
    # Environment variables win over payload-provided credentials.
    aws_access_key = os.getenv("AWS_S3_ACCESS_KEY_ID") \
        or input_dict.get('aws_access_key')
    aws_secret_key = os.getenv("AWS_S3_SECRET_ACCESS_KEY") \
        or input_dict.get('aws_secret_key')
    github_token = os.getenv("GITHUB_TOKEN", input_dict.get('github_token'))
    self.bucket_name = input_dict.get('bucket_name')
    # Serialise hyper params compactly. Only re-dump when the key was
    # actually supplied: previously a missing 'hyper_params' key fell
    # through to json.dumps(None) and set self.hyper_params to 'null'
    # instead of the intended '{}' default.
    self.hyper_params = input_dict.get('hyper_params', '{}')
    if self.hyper_params and 'hyper_params' in input_dict:
        try:
            self.hyper_params = json.dumps(input_dict.get('hyper_params'),
                                           separators=(',', ':'))
        except Exception:
            # Best-effort: keep the raw value, but record the problem.
            logger.error(
                "Invalid hyper params",
                extra={"hyper_params": input_dict.get('hyper_params')})
    self.properties = {
        'AWS_S3_ACCESS_KEY_ID': aws_access_key,
        'AWS_S3_SECRET_ACCESS_KEY': aws_secret_key,
        'AWS_S3_BUCKET_NAME': self.bucket_name,
        'MODEL_VERSION': self.data_version,
        'DEPLOYMENT_PREFIX': self.env,
        'GITHUB_TOKEN': github_token
    }
    self.aws_emr = AmazonEmr(aws_access_key_id=aws_access_key,
                             aws_secret_access_key=aws_secret_key)
    self.aws_emr_client = self.aws_emr.connect()
    if not self.aws_emr.is_connected():
        logger.error("Unable to connect to emr instance.")
        # Carry a message so callers catching ValueError can report it.
        raise ValueError("Unable to connect to emr instance.")
    logger.info("Successfully connected to emr instance.")
def version_details():
    """POST call to fetch model details.

    Expects a JSON dict with bucket_name and ecosystem and returns the
    trained model details for that combination.

    :raises HTTPError: 400 on a missing/invalid payload or missing fields.
    """
    required_fields = ["bucket_name", "ecosystem"]
    input_data = request.get_json()
    # Validate the payload shape BEFORE field checks: check_field_exists
    # raises a bare ValueError for non list/dict input (e.g. a None body),
    # which would escape as a 500 instead of the intended HTTP 400.
    if not input_data:
        raise HTTPError(400, "Expected JSON request")
    if not isinstance(input_data, dict):
        raise HTTPError(400, "Expected dict of input parameters")
    missing_fields = check_field_exists(input_data, required_fields)
    if missing_fields:
        raise HTTPError(400, "These field(s) {} are missing from input "
                             "data".format(missing_fields))
    bucket = input_data['bucket_name']
    ecosystem = input_data['ecosystem']
    output = trained_model_details(bucket, ecosystem)
    return output