コード例 #1
0
def export_saved_model(checkpoint_dir, model_path, s3=None):
    """Copy the most recent exported SavedModel into ``model_path``.

    For an ``s3://`` ``checkpoint_dir`` the objects under
    ``<key_prefix>/export/Servo`` are listed, the last ``saved_model.pb`` in the
    listing is downloaded, and the ``variables/variables*`` files that sit next
    to that same ``saved_model.pb`` are downloaded into a sibling
    ``variables`` directory. For a local path, a file is copied with
    ``shutil.copy2`` and a directory has its contents copied recursively.

    Args:
        checkpoint_dir: ``s3://bucket/prefix`` URL or a local file/directory.
        model_path: local destination directory.
        s3: optional S3 client exposing ``list_objects_v2`` and
            ``download_file``. When omitted, a ``boto3`` client is created on
            first use instead of once at import time.

    Raises:
        KeyError: if the S3 listing response has no ``"Contents"`` entry.
    """
    if checkpoint_dir.startswith('s3://'):
        if s3 is None:
            s3 = boto3.client('s3')
        bucket_name, key_prefix = cs.parse_s3_url(checkpoint_dir)
        prefix = os.path.join(key_prefix, 'export', 'Servo')

        try:
            contents = s3.list_objects_v2(Bucket=bucket_name,
                                          Prefix=prefix)["Contents"]
        except KeyError as e:
            logger.error(
                "Failed to download saved model. File does not exist in {}".
                format(checkpoint_dir))
            raise e

        saved_model_path_array = [
            x['Key'] for x in contents
            if x['Key'].endswith('saved_model.pb')
        ]
        if not saved_model_path_array:
            logger.info(
                "Failed to download saved model. File does not exist in {}"
                .format(checkpoint_dir))
            return

        # Select most recent saved_model.pb: S3 lists keys in ascending order
        # and export directories are presumably timestamp-named, so the last
        # match is the newest export.
        saved_model_path = saved_model_path_array[-1]
        saved_model_base_path = os.path.dirname(saved_model_path)

        # Only take variable files belonging to the selected export. Keying by
        # bare filename across every export would mix files from different
        # model versions (e.g. stale shards from an older export).
        variables_key_prefix = saved_model_base_path + '/variables/variables'
        variable_names_to_paths = {
            key.split('/')[-1]: key
            for key in (x['Key'] for x in contents)
            if key.startswith(variables_key_prefix)
        }

        prefixes = key_prefix.split('/')
        folders = saved_model_path.split('/')[len(prefixes):]
        saved_model_filename = folders.pop()
        path_to_save_model = os.path.join(model_path, *folders)
        path_to_variables = os.path.join(path_to_save_model, 'variables')

        try:
            os.makedirs(path_to_variables)
        except OSError:
            # Tolerate an already-existing directory (e.g. a re-run); any other
            # failure is re-raised.
            if not os.path.isdir(path_to_variables):
                raise

        target = os.path.join(path_to_save_model, saved_model_filename)
        s3.download_file(bucket_name, saved_model_path, target)
        logger.info("Downloaded saved model at {}".format(target))

        for filename, key in variable_names_to_paths.items():
            s3.download_file(bucket_name, key,
                             os.path.join(path_to_variables, filename))
    elif os.path.isdir(checkpoint_dir):
        # shutil.copy2 cannot copy a directory, so copy the tree file by file
        # into model_path.
        for root, _, files in os.walk(checkpoint_dir):
            dest_dir = os.path.normpath(
                os.path.join(model_path, os.path.relpath(root,
                                                         checkpoint_dir)))
            if not os.path.isdir(dest_dir):
                os.makedirs(dest_dir)
            for name in files:
                shutil.copy2(os.path.join(root, name),
                             os.path.join(dest_dir, name))
    elif os.path.exists(checkpoint_dir):
        shutil.copy2(checkpoint_dir, model_path)
    else:
        logger.error(
            "Failed to copy saved model. File does not exist in {}".format(
                checkpoint_dir))
コード例 #2
0
    def _configure_s3_file_system(self):
        """Set the environment variables read by TensorFlow's S3 file system.

        ``S3_REGION`` is set to the location of the bucket named in
        ``self.model_path`` (only when S3 reports a non-empty location
        constraint), and ``S3_USE_HTTPS`` is always set to ``"1"``.
        """
        # Build the client in the container's region when AWS_REGION is set;
        # None lets boto3 fall back to its default endpoint resolution.
        s3 = boto3.client('s3', region_name=os.environ.get('AWS_REGION'))

        bucket_name, _ = parse_s3_url(self.model_path)

        # The bucket may live in a different region than this container.
        bucket_location = s3.get_bucket_location(Bucket=bucket_name)['LocationConstraint']

        if bucket_location:
            os.environ['S3_REGION'] = bucket_location
        os.environ['S3_USE_HTTPS'] = "1"
コード例 #3
0
def export_saved_model(checkpoint_dir, model_path, s3=None):
    """Download (or copy) the most recent exported SavedModel into ``model_path``.

    For an ``s3://`` ``checkpoint_dir``, every file object under the directory
    of the last ``saved_model.pb`` listed below ``<key_prefix>/export/Servo``
    is downloaded, keeping its path relative to that directory. The export
    directory is placed under ``model_path`` using the folders between
    ``key_prefix`` and ``saved_model.pb``. A local ``checkpoint_dir`` is passed
    to ``_recursive_copy``.

    Args:
        checkpoint_dir: ``s3://bucket/prefix`` URL or a local path.
        model_path: local destination directory.
        s3: optional S3 client exposing ``list_objects_v2`` and
            ``download_file``. When omitted, a boto3 client for ``AWS_REGION``
            is created on first use rather than once at import time.

    Raises:
        KeyError: if the S3 listing response has no ``"Contents"`` entry.
    """
    if checkpoint_dir.startswith('s3://'):
        if s3 is None:
            s3 = boto3.client('s3', region_name=os.environ.get('AWS_REGION'))
        bucket_name, key_prefix = cs.parse_s3_url(checkpoint_dir)
        prefix = os.path.join(key_prefix, 'export', 'Servo')

        try:
            contents = s3.list_objects_v2(Bucket=bucket_name,
                                          Prefix=prefix)["Contents"]
        except KeyError as e:
            logger.error(
                "Failed to download saved model. File does not exist in {}".
                format(checkpoint_dir))
            raise e

        saved_model_path_array = [
            x['Key'] for x in contents
            if x['Key'].endswith('saved_model.pb')
        ]
        if not saved_model_path_array:
            logger.info(
                "Failed to download saved model. File does not exist in {}"
                .format(checkpoint_dir))
            return

        # Select most recent saved_model.pb (S3 lists keys in ascending order).
        saved_model_path = saved_model_path_array[-1]
        saved_model_base_path = os.path.dirname(saved_model_path)

        prefixes = key_prefix.split('/')
        folders = saved_model_path.split('/')[len(prefixes):-1]
        path_to_save_model = os.path.join(model_path, *folders)

        # Match on "<base>/" so that a sibling export whose name merely starts
        # with the same characters (".../123" vs ".../1234") is not pulled in.
        # Keys ending in "/" are folder placeholders, not files.
        base_with_sep = saved_model_base_path + '/'
        paths_to_copy = [
            x['Key'] for x in contents
            if x['Key'].startswith(base_with_sep)
            and not x['Key'].endswith('/')
        ]

        for key in paths_to_copy:
            # Plain prefix replacement. A regex substitution would treat the S3
            # key as a pattern and the local path as a replacement template,
            # so characters such as '+', '(' or '\\' could corrupt the target.
            target = path_to_save_model + key[len(saved_model_base_path):]
            _makedirs_for_file(target)
            s3.download_file(bucket_name, key, target)
        logger.info("Downloaded saved model at {}".format(path_to_save_model))
    else:
        if os.path.exists(checkpoint_dir):
            _recursive_copy(checkpoint_dir, model_path)
        else:
            logger.error(
                "Failed to copy saved model. File does not exist in {}".format(
                    checkpoint_dir))
コード例 #4
0
def configure_s3_fs(checkpoint_path):
    """Export the environment variables TensorFlow's S3 file system reads.

    ``S3_REGION`` is set to the location of the bucket in ``checkpoint_path``
    (only when S3 reports a non-empty location constraint) and
    ``S3_USE_HTTPS`` is always set to ``'1'``.
    """
    # AWS_REGION may be unset; region_name=None means the global endpoint.
    client = boto3.client('s3', region_name=os.environ.get('AWS_REGION'))

    # The checkpoint bucket can live in a different region than the one this
    # container runs in, so ask S3 where it is.
    bucket, _unused_key = cs.parse_s3_url(checkpoint_path)
    location = client.get_bucket_location(Bucket=bucket)['LocationConstraint']

    if location:
        os.environ['S3_REGION'] = location
    os.environ['S3_USE_HTTPS'] = '1'
コード例 #5
0
def configure_s3_fs(checkpoint_path):
    """Point TensorFlow's S3 file system at the checkpoint bucket.

    Sets ``S3_REGION`` from the bucket's location constraint when S3 reports
    one, and unconditionally sets ``S3_USE_HTTPS`` to ``'1'``.
    """
    env = os.environ
    # No AWS_REGION -> None -> boto3 uses the global S3 endpoint.
    s3_client = boto3.client('s3', region_name=env.get('AWS_REGION'))

    # Bucket region may differ from the container's region.
    bucket, _ = cs.parse_s3_url(checkpoint_path)
    location_response = s3_client.get_bucket_location(Bucket=bucket)
    bucket_region = location_response['LocationConstraint']

    # Only override the region when S3 actually reported one.
    if bucket_region:
        env['S3_REGION'] = bucket_region
    env['S3_USE_HTTPS'] = '1'
コード例 #6
0
def export_saved_model(checkpoint_dir, model_path, s3=None):
    """Download (or copy) the most recent exported SavedModel into ``model_path``.

    For an ``s3://`` ``checkpoint_dir``, the last ``saved_model.pb`` listed under
    ``<key_prefix>/export/Servo`` is downloaded together with the ``variables/variables*``
    files from that same export directory. A local ``checkpoint_dir`` is passed to
    ``_recursive_copy``.

    Args:
        checkpoint_dir: ``s3://bucket/prefix`` URL or a local path.
        model_path: local destination directory.
        s3: optional S3 client exposing ``list_objects_v2`` and ``download_file``. When
            omitted, a boto3 client is created on first use rather than at import time.

    Raises:
        KeyError: if the S3 listing response has no ``"Contents"`` entry.
    """
    if checkpoint_dir.startswith('s3://'):
        if s3 is None:
            s3 = boto3.client('s3')
        bucket_name, key_prefix = cs.parse_s3_url(checkpoint_dir)
        prefix = os.path.join(key_prefix, 'export', 'Servo')

        try:
            contents = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)["Contents"]
        except KeyError as e:
            logger.error("Failed to download saved model. File does not exist in {}".format(checkpoint_dir))
            raise e

        saved_model_path_array = [x['Key'] for x in contents if x['Key'].endswith('saved_model.pb')]
        if not saved_model_path_array:
            logger.info("Failed to download saved model. File does not exist in {}".format(checkpoint_dir))
            return

        # Select most recent saved_model.pb (S3 lists keys in ascending order).
        saved_model_path = saved_model_path_array[-1]
        saved_model_base_path = os.path.dirname(saved_model_path)

        # Only take variable files from the selected export; collecting them from every
        # export keyed by bare filename could mix in stale shards from older versions.
        variables_key_prefix = saved_model_base_path + '/variables/variables'
        variable_names_to_paths = {key.split('/')[-1]: key
                                   for key in (x['Key'] for x in contents)
                                   if key.startswith(variables_key_prefix)}

        prefixes = key_prefix.split('/')
        folders = saved_model_path.split('/')[len(prefixes):]
        saved_model_filename = folders.pop()
        path_to_save_model = os.path.join(model_path, *folders)
        path_to_variables = os.path.join(path_to_save_model, 'variables')

        try:
            os.makedirs(path_to_variables)
        except OSError:
            # Tolerate an already-existing directory (e.g. a re-run); re-raise anything else.
            if not os.path.isdir(path_to_variables):
                raise

        target = os.path.join(path_to_save_model, saved_model_filename)
        s3.download_file(bucket_name, saved_model_path, target)
        logger.info("Downloaded saved model at {}".format(target))

        for filename, key in variable_names_to_paths.items():
            s3.download_file(bucket_name, key, os.path.join(path_to_variables, filename))
    else:
        if os.path.exists(checkpoint_dir):
            _recursive_copy(checkpoint_dir, model_path)
        else:
            logger.error("Failed to copy saved model. File does not exist in {}".format(checkpoint_dir))
コード例 #7
0
def export_saved_model(checkpoint_dir, model_path, s3=None):
    """Download (or copy) the most recent exported SavedModel into ``model_path``.

    For an ``s3://`` ``checkpoint_dir``, every file object under the directory of the last
    ``saved_model.pb`` listed below ``<key_prefix>/export/Servo`` is downloaded, keeping its
    path relative to that directory. A local ``checkpoint_dir`` is passed to
    ``_recursive_copy``.

    Args:
        checkpoint_dir: ``s3://bucket/prefix`` URL or a local path.
        model_path: local destination directory.
        s3: optional S3 client exposing ``list_objects_v2`` and ``download_file``. When
            omitted, a boto3 client for ``AWS_REGION`` is created on first use rather than
            once at import time.

    Raises:
        KeyError: if the S3 listing response has no ``"Contents"`` entry.
    """
    if checkpoint_dir.startswith('s3://'):
        if s3 is None:
            s3 = boto3.client('s3', region_name=os.environ.get('AWS_REGION'))
        bucket_name, key_prefix = cs.parse_s3_url(checkpoint_dir)
        prefix = os.path.join(key_prefix, 'export', 'Servo')

        try:
            contents = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)["Contents"]
        except KeyError as e:
            logger.error("Failed to download saved model. File does not exist in {}".format(checkpoint_dir))
            raise e

        saved_model_path_array = [x['Key'] for x in contents if x['Key'].endswith('saved_model.pb')]
        if not saved_model_path_array:
            logger.info("Failed to download saved model. File does not exist in {}".format(checkpoint_dir))
            return

        # Select most recent saved_model.pb (S3 lists keys in ascending order).
        saved_model_path = saved_model_path_array[-1]
        saved_model_base_path = os.path.dirname(saved_model_path)

        prefixes = key_prefix.split('/')
        folders = saved_model_path.split('/')[len(prefixes):-1]
        path_to_save_model = os.path.join(model_path, *folders)

        # Match on "<base>/" so a sibling export such as ".../1234" is not pulled in when the
        # selected one is ".../123". Keys ending in "/" are folder placeholders.
        base_with_sep = saved_model_base_path + '/'
        paths_to_copy = [x['Key'] for x in contents
                         if x['Key'].startswith(base_with_sep) and not x['Key'].endswith('/')]

        for key in paths_to_copy:
            # Plain prefix replacement: re.sub would treat the key as a regex pattern and the
            # local path as a replacement template, corrupting targets containing '+', '(' or '\\'.
            target = path_to_save_model + key[len(saved_model_base_path):]
            _makedirs_for_file(target)
            s3.download_file(bucket_name, key, target)
        logger.info("Downloaded saved model at {}".format(path_to_save_model))
    else:
        if os.path.exists(checkpoint_dir):
            _recursive_copy(checkpoint_dir, model_path)
        else:
            logger.error("Failed to copy saved model. File does not exist in {}".format(checkpoint_dir))
コード例 #8
0
def test_parse_s3_url_no_key():
    """A bucket-only URL with a trailing slash yields an empty key."""
    expected = ("bucket", "")
    assert expected == cs.parse_s3_url("s3://bucket/")
コード例 #9
0
def test_parse_s3_url():
    """A URL with a single-segment key splits into (bucket, key)."""
    expected = ("bucket", "key")
    assert expected == cs.parse_s3_url("s3://bucket/key")
コード例 #10
0
def test_parse_s3_url_invalid():
    """A URL whose scheme is not s3:// is rejected with ValueError."""
    bad_url = "nots3://blah/blah"
    with pytest.raises(ValueError):
        cs.parse_s3_url(bad_url)