Beispiel #1
0
    def construct_job(self, input_dict):
        """Validate job inputs, store the job settings and connect to EMR.

        Despite the name, this method submits nothing. It checks
        ``input_dict`` and populates the attributes later used for the job:
        ``env``, ``data_version``, ``training_repo_url``, ``hyper_params``,
        ``bucket_name`` and ``properties``. It then opens an EMR connection,
        kept in ``aws_emr`` / ``aws_emr_client``.

        :param input_dict: job parameters. Must contain ``environment``,
            ``data_version``, ``bucket_name`` and ``github_repo``. May
            contain ``hyper_params``, ``aws_access_key``, ``aws_secret_key``
            and ``github_token``.
        :raises ValueError: if required fields are missing, the GitHub repo
            URL is not reachable, ``hyper_params`` cannot be serialized to
            JSON, or the EMR connection could not be established.
        """
        required_fields = [
            'environment', 'data_version', 'bucket_name', 'github_repo'
        ]

        missing_fields = check_field_exists(input_dict, required_fields)

        if missing_fields:
            logger.error("Missing the parameters in input_dict",
                         extra={"missing_fields": missing_fields})
            raise ValueError(
                "Required fields are missing in the input {}".format(
                    missing_fields))

        self.env = input_dict.get('environment')
        self.data_version = input_dict.get('data_version')
        github_repo = input_dict.get('github_repo')
        if not check_url_alive(github_repo):
            logger.error(
                "Unable to find the github_repo {}".format(github_repo))
            raise ValueError(
                "Unable to find the github_repo {}".format(github_repo))
        self.training_repo_url = github_repo

        # Hyper params are stored as a compact JSON string. Only serialize a
        # value the caller actually supplied. Previously the default '{}' was
        # truthy, so a missing key became json.dumps(None) == 'null'.
        raw_hyper_params = input_dict.get('hyper_params')
        if raw_hyper_params:
            try:
                self.hyper_params = json.dumps(raw_hyper_params,
                                               separators=(',', ':'))
            except (TypeError, ValueError):
                logger.error(
                    "Invalid hyper params",
                    extra={"hyper_params": raw_hyper_params})
                # Fail fast instead of continuing with an unserialized
                # object where a JSON string is expected.
                raise ValueError("Invalid hyper params")
        else:
            # Absent or empty hyper params: always an empty JSON object.
            self.hyper_params = '{}'

        # Environment variables take precedence over input_dict. The
        # fallback applies when a variable is unset *or* empty, and it is
        # handled the same way for all three credentials.
        aws_access_key = (os.getenv("AWS_S3_ACCESS_KEY_ID")
                          or input_dict.get('aws_access_key'))
        aws_secret_key = (os.getenv("AWS_S3_SECRET_ACCESS_KEY")
                          or input_dict.get('aws_secret_key'))
        github_token = (os.getenv("GITHUB_TOKEN")
                        or input_dict.get('github_token'))
        self.bucket_name = input_dict.get('bucket_name')

        self.properties = {
            'AWS_S3_ACCESS_KEY_ID': aws_access_key,
            'AWS_S3_SECRET_ACCESS_KEY': aws_secret_key,
            'AWS_S3_BUCKET_NAME': self.bucket_name,
            'MODEL_VERSION': self.data_version,
            'DEPLOYMENT_PREFIX': self.env,
            'GITHUB_TOKEN': github_token
        }

        self.aws_emr = AmazonEmr(aws_access_key_id=aws_access_key,
                                 aws_secret_access_key=aws_secret_key)

        self.aws_emr_client = self.aws_emr.connect()

        if not self.aws_emr.is_connected():
            logger.error("Unable to connect to emr instance.")
            raise ValueError("Unable to connect to emr instance.")

        logger.info("Successfully connected to emr instance.")
def get_training_file_url(user, repo, branch='master',
                          training_file_path='training/train.py'):
    """Build the raw-content URL of a training file on GitHub and verify it.

    :param user: GitHub user or organisation that owns the repository.
    :param repo: repository name.
    :param branch: branch to read the file from.
    :param training_file_path: path of the file inside the repository.
    :returns: the file URL, once ``check_url_alive`` reports it reachable.
    :raises ValueError: if ``user`` or ``repo`` is empty, or the URL is not
        reachable.
    """
    # Both parts are needed to build a valid path, so reject the call when
    # *either* one is missing. The previous `and` rejected only the case
    # where both were missing.
    if not user or not repo:
        logger.error("Please provide the github user and repo",
                     extra={"user": user, "repo": repo})
        raise ValueError("Please provide the github user:{} and repo:{}"
                         .format(user, repo))

    # Strip a leading slash so the joined path never contains '//'.
    # NOTE(review): urljoin drops the last path segment of the base unless
    # GITHUB_CONTENT_BASEURL ends with '/'; presumably it does, but confirm.
    file_url = urljoin(GITHUB_CONTENT_BASEURL,
                       '/'.join((user, repo, branch,
                                 training_file_path.lstrip('/'))))

    if not check_url_alive(file_url):
        logger.error("unable to reach the github training file path",
                     extra={'github_url': file_url})
        raise ValueError("Could not able to fetch training file")
    return file_url
Beispiel #3
0
def test_check_url_alive():
    """A reachable site reports alive; an unresolvable domain does not."""
    reachable = 'https://google.com'
    unreachable = 'https://234j23ksadasca.com'
    assert check_url_alive(reachable)
    assert not check_url_alive(unreachable)