예제 #1
0
def get_project_or_local(project=None, is_cli: bool = False):
    """Resolve an ``(owner, project_name)`` pair from an argument or the local config.

    Args:
        project: optional project reference, parsed with ``get_project_info``.
            When falsy, the config from ``ProjectConfigManager.get_config()``
            is used instead.
        is_cli: when True, failures are printed through ``Printer.print_error``
            followed by ``sys.exit(1)``; otherwise they are raised.

    Returns:
        Tuple ``(owner, project_name)``.

    Raises:
        PolyaxonClientException: if no valid project can be resolved and
            ``is_cli`` is False.
    """
    from polyaxon import settings

    def _invalid_project():
        # Single failure policy shared by both validation steps below.
        if is_cli:
            Printer.print_error("Please provide a valid project.")
            sys.exit(1)
        else:
            raise PolyaxonClientException("Please provide a valid project.")

    if not (project or ProjectConfigManager.is_initialized()):
        _invalid_project()

    if project:
        owner, project_name = get_project_info(project)
    else:
        local_config = ProjectConfigManager.get_config()
        owner, project_name = local_config.owner, local_config.name

    if not owner:
        # Use the default owner when there is no CLI config or it reports `is_ce`.
        if not settings.CLI_CONFIG or settings.CLI_CONFIG.is_ce:
            owner = DEFAULT

    if not (owner and project_name):
        _invalid_project()
    return owner, project_name
예제 #2
0
파일: run.py 프로젝트: smilee/polyaxon
    def __init__(
        self,
        owner: str = None,
        project: str = None,
        run_uuid: str = None,
        client: PolyaxonClient = None,
    ):
        """Initialise the run client for an owner/project/run.

        Args:
            owner: optional owner name.
            project: optional project name.
            run_uuid: optional run uuid, passed through ``get_run_or_local``.
            client: optional client instance; one is created unless the client
                config is offline.

        Raises:
            PolyaxonClientException: if no project, or no owner, can be resolved.
        """
        try:
            full_name = get_project_full_name(owner=owner, project=project)
            owner, project = get_project_or_local(full_name)
        except PolyaxonClientException:
            # Best effort: keep the values passed in and validate them below.
            pass

        if project is None:
            if not settings.CLIENT_CONFIG.is_managed:
                raise PolyaxonClientException(
                    "Please provide a valid project.")
            owner, project, env_run_uuid = get_run_info()
            # An explicitly passed uuid wins over the one from the run info.
            run_uuid = run_uuid or env_run_uuid

        if not (owner and project):
            raise PolyaxonClientException(
                "Please provide a valid project with owner.")

        self.client = client
        if not self.client and not settings.CLIENT_CONFIG.is_offline:
            self.client = PolyaxonClient()

        self._owner = owner
        self._project = project
        self._run_uuid = get_run_or_local(run_uuid)
        self._run_data = polyaxon_sdk.V1Run()
        self._namespace = None
예제 #3
0
def get_model_info(entity: str, entity_name: str, is_cli: bool = False):
    """Resolve ``(owner, entity_namespace, version)`` for a versioned entity reference.

    Args:
        entity: the entity reference, parsed by ``get_versioned_entity_info``.
        entity_name: label for the kind of entity; used in the error messages
            and forwarded to ``get_versioned_entity_info``.
        is_cli: when True, failures are printed through ``Printer.print_error``
            followed by ``sys.exit(1)``; otherwise they are raised.

    Returns:
        Tuple ``(owner, entity_namespace, version)``.

    Raises:
        PolyaxonClientException: if ``entity`` is empty, or if the owner or the
            parsed namespace is missing, and ``is_cli`` is False.
    """
    from polyaxon import settings

    def _fail(message):
        if is_cli:
            Printer.print_error(message)
            sys.exit(1)
        else:
            raise PolyaxonClientException(message)

    if not entity:
        _fail("Please provide a valid {}!".format(entity_name))

    owner = get_local_owner(is_cli=is_cli)

    if not owner and (not settings.CLI_CONFIG or settings.CLI_CONFIG.is_ce):
        owner = DEFAULT

    owner, entity_namespace, version = get_versioned_entity_info(
        entity=entity, entity_name=entity_name, default_owner=owner
    )

    owner = owner or settings.AUTH_CONFIG.username

    # Validate the parsed namespace as well: `entity_name` is only the label
    # used in messages, so checking it alone lets an empty namespace through.
    if not all([owner, entity_name, entity_namespace]):
        _fail("Please provide a valid {}.".format(entity_name))
    return owner, entity_namespace, version
예제 #4
0
def get_project_or_local(project=None, is_cli: bool = False):
    """Resolve ``(owner, project_name)`` from an argument or the local project.

    Args:
        project: optional project reference, parsed with ``get_entity_info``.
            When falsy, the project from ``get_local_project()`` is used.
        is_cli: when True, failures are printed through ``Printer.print_error``
            followed by ``sys.exit(1)``; otherwise they are raised.

    Returns:
        Tuple ``(owner, project_name)``.

    Raises:
        PolyaxonClientException: if no valid owner/project can be resolved and
            ``is_cli`` is False.
    """
    from polyaxon import settings

    def _fail(message):
        # Single failure policy shared by both validation steps below.
        if is_cli:
            Printer.print_error(message)
            sys.exit(1)
        else:
            raise PolyaxonClientException(message)

    if not (project or ProjectConfigManager.is_initialized()):
        _fail(
            "Please provide a valid project or initialize a project in the current path."
        )

    if project:
        owner, project_name = get_entity_info(project)
    else:
        local_project = get_local_project()
        owner, project_name = local_project.owner, local_project.name

    # Owner resolution order: explicit value, then local owner, then DEFAULT.
    owner = owner or get_local_owner(is_cli=is_cli)

    if not owner:
        if not settings.CLI_CONFIG or settings.CLI_CONFIG.is_ce:
            owner = DEFAULT

    if not (owner and project_name):
        _fail(get_project_error_message(owner, project_name))
    return owner, project_name
예제 #5
0
파일: run.py 프로젝트: klonggan/polyaxon
    def create(
        self,
        name=None,
        tags=None,
        description=None,
        content=None,
        base_outputs_path=None,
    ):
        """Create a run, register it on this instance and start tracking it.

        Args:
            name: optional run name; only set when truthy.
            tags: optional tags; only set when truthy.
            description: optional description; only set when truthy.
            content: optional content parsed with ``get_specification``; when
                omitted the run is flagged with ``is_managed = False``.
            base_outputs_path: optional base path; when given and no outputs
                store is set yet, a store is configured under
                ``<base>/<owner>/<project>/<run_uuid>``.

        Returns:
            This instance.

        Raises:
            PolyaxonClientException: if the content cannot be parsed, the API
                call fails, or the API returns no run.
        """
        payload = polyaxon_sdk.V1Run()
        if self.track_env:
            payload.run_env = get_run_env()
        # Copy only the optional metadata that was actually provided.
        for field, value in (
            ("name", name),
            ("tags", tags),
            ("description", description),
        ):
            if value:
                setattr(payload, field, value)

        if not content:
            payload.is_managed = False
        else:
            try:
                parsed = get_specification(data=[content])
            except Exception as e:
                raise PolyaxonClientException(e)
            payload.content = parsed.config_dump

        # Without a client the locally built payload is used as the run.
        created = payload
        if self.client:
            try:
                created = self.client.runs_v1.create_run(
                    owner=self.owner, project=self.project, body=payload
                )
            except (ApiException, HTTPError) as e:
                raise PolyaxonClientException(e)
            if not created:
                raise PolyaxonClientException("Could not create a run.")

        if not settings.CLIENT_CONFIG.is_managed:
            if self.track_logs:
                setup_logging(send_logs=self.send_logs)

        self._run = created
        self._run_uuid = created.uuid
        self.status = "created"

        # Setup the outputs store
        if self.outputs_store is None:
            if base_outputs_path:
                store_path = "{}/{}/{}/{}".format(
                    base_outputs_path, self.owner, self.project, self.run_uuid
                )
                self.set_outputs_store(outputs_path=store_path)

        if self.track_code:
            self.log_code_ref()

        if settings.CLIENT_CONFIG.is_managed:
            self._register_wait()
        else:
            self._start()

        return self
예제 #6
0
파일: run.py 프로젝트: gregmbi/polyaxon
    def create(self, name=None, tags=None, description=None, content=None):
        """Create a run from an operation body and start tracking it.

        Args:
            name: optional run name; only set when truthy.
            tags: optional tags; only set when truthy.
            description: optional description; only set when truthy.
            content: optional content read with ``OperationSpecification.read``;
                when omitted the operation is flagged ``is_managed = False``.

        Returns:
            This instance.

        Raises:
            PolyaxonClientException: if the content cannot be read, the API
                call fails, or the API returns no run.
        """
        body = polyaxon_sdk.V1OperationBody()
        # Copy only the optional metadata that was actually provided.
        for field, value in (
            ("name", name),
            ("tags", tags),
            ("description", description),
        ):
            if value:
                setattr(body, field, value)

        if not content:
            body.is_managed = False
        else:
            try:
                parsed = OperationSpecification.read(content)
            except Exception as e:
                raise PolyaxonClientException("Client error: %s" % e) from e
            body.content = parsed.to_dict(dump=True)

        if not self.client:
            # No client: build the run locally from the operation body.
            created = polyaxon_sdk.V1Run(
                name=body.name,
                tags=body.tags,
                description=body.description,
                content=body.content,
                is_managed=body.is_managed,
            )
        else:
            try:
                created = self.client.runs_v1.create_run(
                    owner=self.owner, project=self.project, body=body
                )
            except (ApiException, HTTPError) as e:
                raise PolyaxonClientException("Client error: %s" % e) from e
            if not created:
                raise PolyaxonClientException("Could not create a run.")

        self._run = created
        self._run_uuid = created.uuid

        if self.artifacts_path:
            self.set_run_event_logger()

        if self.track_code:
            self.log_code_ref()
        if self.track_env:
            self.log_run_env()

        if settings.CLIENT_CONFIG.is_managed:
            self._register_wait()
        else:
            self._start()

        return self
예제 #7
0
def get_run_info(run_instance: str = None):
    """Split a dotted run instance into ``(owner, project, run_uuid)``.

    Args:
        run_instance: dotted string with exactly four parts. When falsy, the
            value of the ``POLYAXON_KEYS_RUN_INSTANCE`` environment variable
            is used.

    Returns:
        Tuple of the first, second and last part of the run instance.

    Raises:
        PolyaxonClientException: if no run instance is available or it does
            not have exactly four dot-separated parts.
    """
    run_instance = run_instance or os.getenv(POLYAXON_KEYS_RUN_INSTANCE, None)
    if not run_instance:
        raise PolyaxonClientException(
            "Could not get run info, "
            "please make sure this run is correctly started by Polyaxon.")

    parts = run_instance.split(".")
    if len(parts) != 4:
        raise PolyaxonClientException(
            "run instance is invalid `{}`, "
            "please make sure this run is correctly started by Polyaxon.".
            format(run_instance))
    # The third part is not returned to callers.
    return parts[0], parts[1], parts[-1]
예제 #8
0
파일: run.py 프로젝트: dorucioclea/polyaxon
 def create(
     self,
     name: str = None,
     description: str = None,
     tags: Union[str, Sequence[str]] = None,
     content: Union[str, V1Operation] = None,
 ):
     """Create a run from the given metadata and optional content.

     Args:
         name: optional name forwarded to the operation body.
         description: optional description forwarded to the operation body.
         tags: optional tags forwarded to the operation body.
         content: optional str or V1Operation; a V1Operation is serialised
             with ``to_dict(dump=True)``. Without content the run is created
             with ``is_managed=False``.

     Raises:
         PolyaxonClientException: if ``content`` is truthy but is neither a
             str nor a V1Operation.
     """
     is_managed = bool(content)
     if is_managed:
         if not isinstance(content, (str, V1Operation)):
             raise PolyaxonClientException(
                 "Received an invalid content: {}".format(content))
         if not isinstance(content, str):
             content = content.to_dict(dump=True)
     body = polyaxon_sdk.V1OperationBody(
         name=name,
         description=description,
         tags=tags,
         content=content,
         is_managed=is_managed,
     )
     self._create(data=body, async_req=False)
     self._post_create()
예제 #9
0
    def process_summary(
        cls,
        summary,
        global_step=None,
        run=None,
        log_image: bool = False,
        log_histo: bool = False,
        log_tensor: bool = False,
    ):
        """Forward every value of a summary to ``cls.add_value``.

        Args:
            summary: a summary object exposing ``.value``, or its serialised
                ``bytes`` form, which is parsed into ``summary_pb2.Summary``.
            global_step: step passed through ``cls._process_step``.
            run: optional run, resolved via ``tracking.get_or_create_run``.
            log_image: forwarded to ``cls.add_value``.
            log_histo: forwarded to ``cls.add_value``.
            log_tensor: forwarded to ``cls.add_value``.

        Returns:
            None. Does nothing when no run can be resolved.
        """
        import logging

        run = tracking.get_or_create_run(run)
        if not run:
            return

        if isinstance(summary, bytes):
            summary_proto = summary_pb2.Summary()
            summary_proto.ParseFromString(summary)
            summary = summary_proto

        step = cls._process_step(global_step)
        for value in summary.value:
            try:
                cls.add_value(
                    run=run,
                    step=step,
                    value=value,
                    log_image=log_image,
                    log_histo=log_histo,
                    log_tensor=log_tensor,
                )
            except PolyaxonClientException:
                # An `except` clause must name the exception class. Matching
                # against an instance raises TypeError instead of handling the
                # error. Processing stays best effort: report and carry on
                # with the remaining values.
                logging.getLogger(__name__).warning(
                    "Polyaxon failed processing tensorboard summary."
                )
예제 #10
0
def get_local_owner(is_cli: bool = False):
    """Return the owner from the local user config, or the default owner.

    Args:
        is_cli: when True, a missing owner is printed through
            ``Printer.print_error`` followed by ``sys.exit(1)``; otherwise it
            is raised.

    Returns:
        The ``organization`` of the user config when set; otherwise ``DEFAULT``
        when there is no CLI config or it reports ``is_ce``.

    Raises:
        PolyaxonClientException: if no owner can be resolved and ``is_cli``
            is False.
    """
    from polyaxon import settings

    owner = None
    if UserConfigManager.is_initialized():
        try:
            owner = UserConfigManager.get_config().organization
        except TypeError:
            Printer.print_error(
                "Found an invalid user config or user config cache, "
                "if you are using Polyaxon CLI please run: "
                "`polyaxon config purge --cache-only`",
                sys_exit=True,
            )

    if not owner:
        if not settings.CLI_CONFIG or settings.CLI_CONFIG.is_ce:
            owner = DEFAULT

    if not owner:
        message = "An context owner (user or organization) is required."
        if not is_cli:
            raise PolyaxonClientException(message)
        Printer.print_error(message)
        sys.exit(1)
    return owner
예제 #11
0
def _read_from_polyaxon_hub(hub: str):
    """Fetch a component version from the Polyaxon hub and read its content.

    Args:
        hub: component reference, parsed with ``get_component_info`` into
            owner, component and version.

    Returns:
        The result of ``_read_from_stream`` on the response content.

    Raises:
        PolyaxonClientException: if the API request fails; the original error
            is included in the message and chained as the cause.
    """
    from polyaxon.client import PolyaxonClient
    from polyaxon.constants.globals import DEFAULT_HUB, NO_AUTH
    from polyaxon.env_vars.getters import get_component_info
    from polyaxon.schemas.cli.client_config import ClientConfig

    owner, component, version = get_component_info(hub)

    try:
        if owner == DEFAULT_HUB:
            # The default hub is queried with a fresh config and the NO_AUTH token.
            config = ClientConfig()
            client = PolyaxonClient(
                config=config,
                token=NO_AUTH,
            )
        else:
            client = PolyaxonClient()
        response = client.component_hub_v1.get_component_version(
            owner, component, version
        )
        return _read_from_stream(response.content)
    except (ApiException, HTTPError) as e:
        # The message previously had a single placeholder, so `e` was passed
        # to format() but never rendered.
        raise PolyaxonClientException(
            "Component `{}` could not be fetched, "
            "an error was encountered: {}".format(hub, e)
        ) from e
예제 #12
0
파일: run.py 프로젝트: smilee/polyaxon
    def get_statuses(
            self,
            last_status: str = None) -> Tuple[str, List[V1StatusCondition]]:
        """Gets the run's statuses.

        [Run API](/docs/api/#operation/GetRunStatus)

        Args:
            last_status: str, a valid [Statuses](/docs/core/specification/lifecycle/) value.

        Returns:
            Tuple[str, List[Conditions]], last status and ordered status conditions.
            When `last_status` is given, only the conditions recorded after the
            most recent condition of that type are returned, or an empty list
            if the status did not change.

        Raises:
            PolyaxonClientException: if the API request fails.
        """
        try:
            response = self.client.runs_v1.get_run_statuses(
                self.owner, self.project, self.run_uuid)
        except (ApiException, HTTPError) as e:
            raise PolyaxonClientException("Api error: %s" % e) from e

        if not last_status:
            return response.status, response.status_conditions
        if last_status == response.status:
            return last_status, []

        # Walk backwards until the last known status is found.
        newer = []
        for condition in reversed(response.status_conditions or []):
            if condition.type == last_status:
                break
            newer.append(condition)

        # Return a real list in chronological order: a `reversed` iterator is
        # single-use, always truthy, and does not match the annotation.
        newer.reverse()
        return response.status, newer
예제 #13
0
파일: run.py 프로젝트: eef808a24ff/polyaxon
    def create(
        self,
        name: str = None,
        description: str = None,
        tags: Union[str, Sequence[str]] = None,
        content: Union[str, Dict, V1Operation] = None,
    ):
        """Creates a new run based on the data passed.

        N.B. Create methods are only useful if you want to create a run programmatically,
        if you run a component/operation from the CLI/UI an instance will be created automatically.

        This is the generic entry point; dedicated helpers exist for other sources:
          * from yaml: `create_from_polyaxonfile`
          * from url: `create_from_url`
          * from hub: `create_from_hub`

        > Note that if you don't pass `content`, the creation will pass,
        and the run will be marked as non-managed.

        [Run API](/docs/api/#operation/CreateRun)

        Args:
            name: str, optional, it will override the name in the operation if provided.
            description: str, optional,
                it will override the description in the operation if provided.
            tags: str or List[str], optional, list of tags,
                it will override the tags in the operation if provided.
            content: str or Dict or V1Operation, optional. A Dict is first
                converted with `V1Operation.from_dict`.

        Returns:
            V1Run, run instance from the response.

        Raises:
            PolyaxonClientException: if `content` is provided with an unsupported type.
        """
        is_managed = bool(content)
        if is_managed:
            if not isinstance(content, (str, Mapping, V1Operation)):
                raise PolyaxonClientException(
                    "Received an invalid content: {}".format(content)
                )
            if isinstance(content, Mapping):
                content = V1Operation.from_dict(content)
            # Strings are sent as-is; operations are serialised.
            if not isinstance(content, str):
                content = content.to_dict(dump=True)
        body = polyaxon_sdk.V1OperationBody(
            name=name,
            description=description,
            tags=tags,
            content=content,
            is_managed=is_managed,
        )
        self._create(data=body, async_req=False)
        self._post_create()
        return self.run_data
예제 #14
0
파일: run.py 프로젝트: klonggan/polyaxon
 def set_outputs_store(
     self,
     outputs_store=None,
     outputs_path=None,
     set_env_vars=False,
 ):
     """Attach an outputs store to this instance.

     Args:
         outputs_store: optional store instance, used as-is when truthy.
         outputs_path: optional path used to build a ``StoreManager`` when no
             store instance is given.
         set_env_vars: when True, ``set_env_vars()`` is called on the store.

     Raises:
         PolyaxonClientException: if neither a store nor a path is provided.
     """
     if not (outputs_store or outputs_path):
         raise PolyaxonClientException(
             "An Store instance or and outputs path is required.")
     self.outputs_store = outputs_store or StoreManager(path=outputs_path)
     if self.outputs_store:
         if set_env_vars:
             self.outputs_store.set_env_vars()
예제 #15
0
def impersonate(owner, project, run_uuid):
    """Request an impersonation token for a run and store it as the context auth.

    Args:
        owner: owner passed to ``impersonate_token``.
        project: project passed to ``impersonate_token``.
        run_uuid: run uuid passed to ``impersonate_token``.

    Raises:
        PolyaxonClientException: if any of the API requests fails.
    """
    try:
        token_response = PolyaxonClient().runs_v1.impersonate_token(
            owner, project, run_uuid
        )
        token = token_response.token
        # Look up the user behind the token with a client authenticated by it.
        user = PolyaxonClient(token=token).users_v1.get_user()
        create_context_auth(
            AccessTokenConfig(username=user.username, token=token_response.token)
        )
    except (ApiException, HTTPError) as e:
        raise PolyaxonClientException(
            "This worker is not allowed to run this job %s." % e)
예제 #16
0
def get_project_or_local(project=None, is_cli: bool = False):
    """Resolve a ``(user, project_name)`` pair from an argument or the local config.

    Args:
        project: optional project reference, parsed with ``get_project_info``.
            When falsy, the config from ``ProjectManager.get_config()`` is
            used instead.
        is_cli: when True, failures are printed through ``Printer.print_error``
            followed by ``sys.exit(1)``; otherwise they are raised.

    Returns:
        Tuple ``(user, project_name)``.

    Raises:
        PolyaxonClientException: if no valid project can be resolved and
            ``is_cli`` is False.
    """

    def _invalid_project():
        # Single failure policy shared by both validation steps below.
        if is_cli:
            Printer.print_error("Please provide a valid project.")
            sys.exit(1)
        else:
            raise PolyaxonClientException("Please provide a valid project.")

    if not (project or ProjectManager.is_initialized()):
        _invalid_project()

    if project:
        user, project_name = get_project_info(project)
    else:
        local_config = ProjectManager.get_config()
        user, project_name = local_config.user, local_config.name

    if not (user and project_name):
        _invalid_project()
    return user, project_name
예제 #17
0
def _read_from_public_hub(hub: str):
    """Read a component from the default registry.

    Args:
        hub: reference of the form ``name`` or ``name:version``; a missing or
            empty version defaults to ``"latest"``.

    Returns:
        The result of ``_read_from_url`` for
        ``<registry>/<name>/<version>.yaml``.

    Raises:
        PolyaxonSchemaError: if ``hub`` contains more than one ``:``.
        PolyaxonClientException: if the component cannot be fetched; the
            original HTTP error is chained as the cause.
    """
    hub_values = hub.split(":")
    if len(hub_values) > 2:
        raise PolyaxonSchemaError(
            "Received an invalid hub reference: `{}`".format(hub))
    if len(hub_values) == 2:
        hub_name, version = hub_values
    else:
        hub_name, version = hub_values[0], "latest"
    version = version or "latest"
    registry = get_default_registry()
    url = "{}/{}/{}.yaml".format(registry, hub_name, version)
    try:
        return _read_from_url(url)
    except HTTPError as e:
        # The error may carry no response; read the status defensively so the
        # handler itself cannot fail with an AttributeError.
        response = getattr(e, "response", None)
        if getattr(response, "status_code", None) == 404:
            raise PolyaxonClientException(
                "Component `{}` was not found, "
                "please check that the name and tag are valid".format(hub)
            ) from e
        # The message previously had a single placeholder, so `e` was passed
        # to format() but never rendered.
        raise PolyaxonClientException(
            "Component `{}` could not be fetched, "
            "an error was encountered: {}".format(hub, e)
        ) from e
예제 #18
0
    def create(
        self,
        name: str = None,
        description: str = None,
        tags: Union[str, Sequence[str]] = None,
        content: Union[str, Dict, V1Operation] = None,
    ):
        """Creates a new run based on the data passed.

        This is the generic entry point; other methods exist for creating runs:
          * from yaml
          * from hub
          * from url

        Note that if you don't pass content, the creation will pass,
        and the run will be marked as non-managed.

        [Run API](/docs/api/#operation/CreateRun)

        Args:
            name: str, optional, name
                note it will override the name in the operation if available.
            description: str, optional, description
                note it will override the description in the operation if available.
            tags: str or List[str], optional, list of tags,
                note it will override the tags in the operation if available.
            content: str or Dict or V1Operation, optional. A Dict is first
                converted with `V1Operation.from_dict`.

        Returns:
            None. The request is issued through `self._create`, followed by
            `self._post_create`.

        Raises:
            PolyaxonClientException: if `content` is provided with an unsupported type.
        """
        is_managed = bool(content)
        if is_managed:
            if not isinstance(content, (str, Mapping, V1Operation)):
                raise PolyaxonClientException(
                    "Received an invalid content: {}".format(content))
            if isinstance(content, Mapping):
                content = V1Operation.from_dict(content)
            # Strings are sent as-is; operations are serialised.
            if not isinstance(content, str):
                content = content.to_dict(dump=True)
        body = polyaxon_sdk.V1OperationBody(
            name=name,
            description=description,
            tags=tags,
            content=content,
            is_managed=is_managed,
        )
        self._create(data=body, async_req=False)
        self._post_create()
예제 #19
0
 def get_periodic_http_worker(self, **kwargs):
     """Return the periodic HTTP worker, starting a new one when needed.

     Args:
         **kwargs: forwarded to the new ``PeriodicWorker``; must contain a
             ``request`` key whenever a new worker has to be started.

     Returns:
         The value of ``self.periodic_http_worker``.

     Raises:
         PolyaxonClientException: if a new worker is needed and ``request``
             is missing from ``kwargs``.
     """
     current = self.periodic_http_worker
     if not (current and current.is_alive()):
         if "request" not in kwargs:
             raise PolyaxonClientException(
                 "Periodic worker expects a request argument.")
         replacement = PeriodicWorker(
             callback=self.queue_periodic_request,
             worker_interval=settings.CLIENT_CONFIG.interval,
             worker_timeout=settings.CLIENT_CONFIG.timeout,
             kwargs=kwargs,
         )
         self._periodic_http_worker = replacement
         self._periodic_http_worker.start()
     return self.periodic_http_worker
예제 #20
0
파일: run.py 프로젝트: klonggan/polyaxon
    def __init__(
        self,
        owner=None,
        project=None,
        run_uuid=None,
        client=None,
        track_logs=True,
        track_code=True,
        track_env=False,
        outputs_store=None,
    ):
        """Initialise the tracking state for a run.

        Args:
            owner: optional owner, resolved through ``get_project_info``.
            project: optional project, resolved through ``get_project_info``.
            run_uuid: optional run uuid; in managed mode without a project it
                defaults to the uuid from ``self.get_run_info()``.
            client: optional client; one is created unless the client config
                is offline.
            track_logs: stored on the instance as ``track_logs``.
            track_code: stored on the instance as ``track_code``.
            track_env: stored on the instance; in managed mode it triggers
                ``log_run_env()``.
            outputs_store: optional outputs store; in managed mode a store is
                set up from ``get_outputs_path()`` when none is given.

        Raises:
            PolyaxonClientException: if no project is resolved outside of
                managed mode.
        """
        owner, project = get_project_info(owner=owner, project=project)

        if project is None:
            if not settings.CLIENT_CONFIG.is_managed:
                raise PolyaxonClientException(
                    "Please provide a valid project.")
            owner, project, env_run_uuid = self.get_run_info()
            # An explicitly passed uuid wins over the one from the run info.
            run_uuid = run_uuid or env_run_uuid

        self.status = None
        self.client = client
        if not self.client and not settings.CLIENT_CONFIG.is_offline:
            self.client = PolyaxonClient()

        self.track_logs = track_logs
        self.track_code = track_code
        self.track_env = track_env
        self._owner = owner
        self._project = project
        self._run_uuid = run_uuid
        self.outputs_store = outputs_store

        # Setup the outputs store
        if outputs_store is None:
            if settings.CLIENT_CONFIG.is_managed:
                self.set_outputs_store(
                    outputs_path=get_outputs_path(), set_env_vars=True
                )

        self._run = polyaxon_sdk.V1Run()
        if settings.CLIENT_CONFIG.is_offline:
            # Offline: nothing to refresh or log.
            return

        if self._run_uuid:
            self.refresh_data()

        # Track run env
        if settings.CLIENT_CONFIG.is_managed:
            if self.track_env:
                self.log_run_env()
예제 #21
0
    def process_summary(cls, summary, global_step=None, run=None):
        """Forward every value of a summary to ``cls.add_value``.

        Args:
            summary: a summary object exposing ``.value``, or its serialised
                ``bytes`` form, which is parsed into ``summary_pb2.Summary``.
            global_step: step passed through ``cls._process_step``.
            run: the run to log to; nothing happens when it is falsy.

        Returns:
            None.
        """
        import logging

        if not run:
            return

        if isinstance(summary, bytes):
            summary_proto = summary_pb2.Summary()
            summary_proto.ParseFromString(summary)
            summary = summary_proto

        step = cls._process_step(global_step)
        for value in summary.value:
            try:
                cls.add_value(run=run, step=step, value=value)
            except PolyaxonClientException:
                # An `except` clause must name the exception class. Matching
                # against an instance raises TypeError instead of handling the
                # error. Processing stays best effort: report and carry on
                # with the remaining values.
                logging.getLogger(__name__).warning(
                    "Polyaxon failed processing tensorboard summary.")
예제 #22
0
    def __init__(
        self,
        owner: str = None,
        project: str = None,
        client: PolyaxonClient = None,
    ):
        """Initialise the project client.

        Args:
            owner: optional owner; when missing it is parsed from ``project``
                via ``get_entity_full_name`` / ``get_entity_info``.
            project: optional project name.
            client: optional client instance, stored as-is.

        Raises:
            PolyaxonClientException: if no owner can be determined.
        """
        if project and not owner:
            full_name = get_entity_full_name(owner=owner, entity=project)
            owner, project = get_entity_info(full_name)

        if not owner:
            raise PolyaxonClientException("Please provide a valid owner.")

        self._client = client
        self._owner = owner or DEFAULT
        self._project = project
        self._project_data = polyaxon_sdk.V1Project()
예제 #23
0
    def __init__(
        self, owner: str = None, project: str = None, client: PolyaxonClient = None,
    ):
        """Initialise the project client.

        Args:
            owner: optional owner; when missing it is parsed from ``project``
                via ``get_project_full_name`` / ``get_project_info``.
            project: optional project name.
            client: optional client; one is created unless the client config
                is offline.

        Raises:
            PolyaxonClientException: if no owner can be determined.
        """
        if project and not owner:
            full_name = get_project_full_name(owner=owner, project=project)
            owner, project = get_project_info(full_name)

        if not owner:
            raise PolyaxonClientException("Please provide a valid project owner.")

        self.client = client
        if not self.client and not settings.CLIENT_CONFIG.is_offline:
            self.client = PolyaxonClient()

        self._owner = owner
        self._project = project
        self._project_data = polyaxon_sdk.V1Project()
예제 #24
0
    def check_response_status(self, response, endpoint):
        """Return the response when its status is 2xx, otherwise log and raise.

        Args:
            response: object exposing ``status_code`` and ``text``.
            endpoint: the requested endpoint, used only in the log message.

        Returns:
            The same ``response`` object when ``200 <= status_code < 300``.

        Raises:
            PolyaxonClientException: for any other status code, with the
                message looked up in ``HTTP_ERROR_MESSAGES_MAPPING``.
        """
        is_success = 200 <= response.status_code < 300
        if is_success:
            return response

        try:
            logger.error(
                "Request to %s failed with status code %s. \n"
                "Reason: %s",
                endpoint,
                response.status_code,
                response.text,
            )
        except TypeError:
            # Fall back to a message that only needs the endpoint.
            logger.error("Request to %s failed with status code", endpoint)

        error_message = HTTP_ERROR_MESSAGES_MAPPING.get(response.status_code)
        raise PolyaxonClientException(error_message)
예제 #25
0
def _get_run_statuses(owner, project, run_uuid, last_status=None):
    """Fetch a run's status and the conditions recorded since ``last_status``.

    Args:
        owner: owner passed to ``get_run_statuses``.
        project: project passed to ``get_run_statuses``.
        run_uuid: run uuid passed to ``get_run_statuses``.
        last_status: optional last status already seen by the caller.

    Returns:
        Tuple ``(status, conditions)``. Without ``last_status`` all conditions
        from the response are returned; if the status did not change the list
        is empty; otherwise it holds, in chronological order, the conditions
        recorded after the most recent condition of type ``last_status``.

    Raises:
        PolyaxonClientException: if the API request fails.
    """
    try:
        polyaxon_client = PolyaxonClient()
        response = polyaxon_client.runs_v1.get_run_statuses(
            owner, project, run_uuid)
    except (ApiException, HTTPError) as e:
        raise PolyaxonClientException(e) from e

    if not last_status:
        return response.status, response.status_conditions
    if last_status == response.status:
        # Same tuple shape as the other branches (previously a list).
        return last_status, []

    # Walk backwards until the last known status is found.
    conditions = []
    for c in reversed(response.status_conditions or []):
        if c.type == last_status:
            break
        conditions.append(c)

    # Return a real list: a `reversed` iterator is single-use and always truthy.
    conditions.reverse()
    return response.status, conditions
예제 #26
0
    def sdk_config(self):
        """Build a ``polyaxon_sdk.Configuration`` from this object's settings.

        Returns:
            The populated configuration object.

        Raises:
            PolyaxonClientException: if there is no host and the client is not
                running in-cluster.
        """
        if not (self.host or self.in_cluster):
            raise PolyaxonClientException(
                "Api config requires at least a host if not running in-cluster."
            )

        config = polyaxon_sdk.Configuration()
        # These attributes are copied one-to-one, in this order.
        for attr in (
            "debug",
            "host",
            "verify_ssl",
            "ssl_ca_cert",
            "cert_file",
            "key_file",
            "assert_hostname",
        ):
            setattr(config, attr, getattr(self, attr))
        if self.connection_pool_maxsize:
            config.connection_pool_maxsize = self.connection_pool_maxsize
        if self.token:
            config.api_key["Authorization"] = self.token
            config.api_key_prefix["Authorization"] = self.authentication_type
        return config
예제 #27
0
파일: run.py 프로젝트: dorucioclea/polyaxon
    def get_statuses(self, last_status: str = None):
        """Get the run's status and the conditions recorded since ``last_status``.

        Args:
            last_status: optional last status already seen by the caller.

        Returns:
            Tuple ``(status, conditions)``. Without ``last_status`` all
            conditions from the response are returned; if the status did not
            change the list is empty; otherwise it holds, in chronological
            order, the conditions recorded after the most recent condition of
            type ``last_status``.

        Raises:
            PolyaxonClientException: if the API request fails.
        """
        try:
            response = self.client.runs_v1.get_run_statuses(
                self.owner, self.project, self.run_uuid)
        except (ApiException, HTTPError) as e:
            raise PolyaxonClientException("Api error: %s" % e) from e

        if not last_status:
            return response.status, response.status_conditions
        if last_status == response.status:
            return last_status, []

        # Walk backwards until the last known status is found.
        newer = []
        for condition in reversed(response.status_conditions or []):
            if condition.type == last_status:
                break
            newer.append(condition)

        # Return a real list: a `reversed` iterator is single-use and always truthy.
        newer.reverse()
        return response.status, newer
예제 #28
0
from typing import List

from polyaxon import tracking
from polyaxon.client.decorators import client_handler
from polyaxon.exceptions import PolyaxonClientException
from polyaxon.logger import logger
from polyaxon.utils.np_utils import sanitize_np_types

# Prefer the Keras that ships with TensorFlow; fall back to standalone Keras.
try:
    from tensorflow import keras
except ImportError:
    try:
        import keras
    except ImportError:
        # Neither import works, so the callback below cannot be defined.
        raise PolyaxonClientException("Keras is required to use PolyaxonCallback")


class PolyaxonCallback(keras.callbacks.Callback):
    """Keras callback attached to a Polyaxon tracking run.

    NOTE(review): only the first two statements of ``__init__`` are visible
    in this excerpt; the remaining parameters are accepted but their handling
    is not shown here. Confirm against the full file.
    """

    def __init__(
        self,
        run=None,
        metrics: List[str] = None,
        log_model: bool = True,
        save_weights_only: bool = False,
        log_best_prefix="best",
        mode: str = "auto",
        monitor: str = "val_loss",
    ):
        # The run is resolved through the tracking module from the `run` argument.
        self.run = tracking.get_or_create_run(run)
        # Stored as given (may be None).
        self.metrics = metrics
예제 #29
0
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from polyaxon import tracking
from polyaxon.exceptions import PolyaxonClientException
from polyaxon.tracking.contrib.tensorboard import PolyaxonTensorboardLogger

# TensorFlow itself is required by this module.
try:
    import tensorflow as tf
except ImportError:
    raise PolyaxonClientException(
        "tensorflow is required to use PolyaxonCallback")

# `SessionRunHook` is looked up in each location that may expose it; the name
# stays None when neither import succeeds.
SessionRunHook = None

try:
    from tensorflow.train import SessionRunHook  # noqa
except ImportError:
    pass

try:
    # Import under the name `SessionRunHook`. The previous name,
    # `LoggingTensorHookSessionRunHook`, does not exist in tensorflow, so this
    # fallback could never bind the hook.
    from tensorflow.estimator import SessionRunHook  # noqa
except ImportError:
    pass

if not SessionRunHook:
    raise PolyaxonClientException(
예제 #30
0
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from polyaxon import tracking
from polyaxon.exceptions import PolyaxonClientException

# Prefer the callbacks bundled with TensorFlow; fall back to standalone Keras.
try:
    from tensorflow.keras.callbacks import Callback
    from tensorflow.python.keras.callbacks import ModelCheckpoint
except ImportError:
    try:
        from keras.callbacks import Callback, ModelCheckpoint
    except ImportError:
        # Neither import works, so the callback classes below cannot be defined.
        raise PolyaxonClientException(
            "Keras is required to use PolyaxonKerasCallback/PolyaxonKerasModelCheckpoint"
        )


class PolyaxonKerasCallback(Callback):
    def __init__(self, run=None, metrics=None):
        self.run = tracking.get_or_create_run(run)
        self.metrics = metrics

    def on_epoch_end(self, epoch, logs=None):
        if not logs or not self.run:
            return
        if self.metrics:
            metrics = {
                metric: logs[metric] for metric in self.metrics if metric in logs
            }