Ejemplo n.º 1
0
def download_h5(run, entity=None, project=None, out_dir=None):
    """Download a run's ``DEEP_SUMMARY_FNAME`` file when it is available.

    Args:
        run: Run identifier passed to ``Api.download_url``.
        entity: Entity owning the run; defaults to ``api.settings("entity")``.
        project: Project owning the run; defaults to ``api.settings("project")``.
        out_dir: Output directory forwarded to ``Api.download_write_file``.

    Returns:
        The path returned by ``Api.download_write_file``, or ``None`` when the
        download metadata is missing or carries no ``md5`` value.
    """
    api = Api()
    meta = api.download_url(
        project or api.settings("project"),
        DEEP_SUMMARY_FNAME,
        entity=entity or api.settings("entity"),
        run=run,
    )
    # Only download when the metadata has a non-None md5 entry.
    if not (meta and "md5" in meta and meta["md5"] is not None):
        return None
    # TODO: make this non-blocking
    wandb.termlog("Downloading summary data...")
    path, _ = api.download_write_file(meta, out_dir=out_dir)
    return path
Ejemplo n.º 2
0
def get_env_vars_section(launch_project: LaunchProject, api: Api,
                         workdir: str) -> str:
    """Fill in wandb-specific environment variables as Dockerfile ``ENV`` lines.

    Args:
        launch_project: Supplies the target project/entity, run id and docker
            image written into the environment.
        api: Internal API; supplies ``base_url`` and ``api_key``.
        workdir: Joined with ``DEFAULT_LAUNCH_METADATA_PATH`` to form
            ``WANDB_LAUNCH_CONFIG_PATH``.

    Returns:
        The ``ENV`` instructions joined with newlines.
    """
    # Read the setting once instead of re-querying it in every branch.
    base_url = api.settings("base_url")
    if _is_wandb_local_uri(base_url) and sys.platform == "darwin":
        # Keep the local server's port but point the host at
        # host.docker.internal (presumably so the container can reach the
        # host machine on macOS).
        _, _, port = base_url.split(":")
        base_url = "http://host.docker.internal:{}".format(port)
    elif _is_wandb_dev_uri(base_url):
        base_url = "http://host.docker.internal:9002"
    # NOTE(review): WANDB_API_KEY is emitted as an ENV instruction; if this
    # text is used in a Dockerfile, the key is stored in the built image.
    return "\n".join([
        f"ENV WANDB_BASE_URL={base_url}",
        f"ENV WANDB_API_KEY={api.api_key}",
        f"ENV WANDB_PROJECT={launch_project.target_project}",
        f"ENV WANDB_ENTITY={launch_project.target_entity}",
        f"ENV WANDB_LAUNCH={True}",
        f"ENV WANDB_LAUNCH_CONFIG_PATH={os.path.join(workdir, DEFAULT_LAUNCH_METADATA_PATH)}",
        f"ENV WANDB_RUN_ID={launch_project.run_id or None}",
        f"ENV WANDB_DOCKER={launch_project.docker_image}",
    ])
Ejemplo n.º 3
0
def check_and_download_code_artifacts(entity: str, project: str, run_name: str,
                                      internal_api: Api,
                                      project_dir: str) -> bool:
    """Download the first artifact of type ``code`` logged by a run.

    Args:
        entity: Entity owning the run.
        project: Project owning the run.
        run_name: Run identifier within ``entity/project``.
        internal_api: Internal API whose ``base_url`` configures the public API.
        project_dir: Directory the code artifact is downloaded into.

    Returns:
        True if a code artifact was found and downloaded, False otherwise.
    """
    _logger.info("Checking for code artifacts")
    server_url = internal_api.settings("base_url")
    public_api = wandb.PublicApi(overrides={"base_url": server_url})

    run_path = f"{entity}/{project}/{run_name}"
    logged = public_api.run(run_path).logged_artifacts()

    # Stop at the first artifact whose type is "code"; objects lacking a
    # ``type`` attribute are skipped.
    code_artifact = next(
        (art for art in logged if getattr(art, "type", None) == "code"),
        None,
    )
    if code_artifact is None:
        return False

    code_artifact.download(project_dir)
    return True
Ejemplo n.º 4
0
    def __init__(
        self,
        uri: str,
        api: Api,
        launch_spec: Dict[str, Any],
        target_entity: str,
        target_project: str,
        name: Optional[str],
        docker_config: Dict[str, Any],
        git_info: Dict[str, str],
        overrides: Dict[str, Any],
        resource: str,
        resource_args: Dict[str, Any],
        cuda: Optional[bool],
    ) -> None:
        """Initialize a launch project from a URI plus its launch configuration.

        The URI is classified as a wandb URI, a git URI, or (otherwise) a local
        path, and ``self.source`` / ``self.project_dir`` are set accordingly.
        Temporary directories are created with ``tempfile.mkdtemp`` for
        ``self.aux_dir`` and, for wandb/git sources, ``self.project_dir``.

        Args:
            uri: Location of the project. A bare wandb URI (per
                ``utils.is_bare_wandb_uri``) is prefixed with the API's
                ``base_url``.
            api: Internal API; stored as ``self.api`` and read here for
                ``base_url``.
            launch_spec: Launch spec; a truthy ``"resource_args"`` entry
                replaces the ``resource_args`` argument.
            target_entity: Stored as ``self.target_entity``.
            target_project: Stored as ``self.target_project``.
            name: Stored as ``self.name``.
            docker_config: Keys read: ``build_image`` (default False),
                ``python_version``, ``cuda_version``, ``base_image``,
                ``docker_image`` and ``user_id``.
            git_info: Keys read: ``version`` and ``repo``.
            overrides: Keys read: ``args`` and ``run_config`` (each defaulting
                to ``{}``) and, if present, ``entry_point``.
            resource: Backend name; also selects the default docker user id
                from ``RESOURCE_UID_MAP`` (1000 when not listed).
            resource_args: Resource-specific arguments.
            cuda: Stored as ``self.cuda``.

        Raises:
            LaunchError: If the URI is neither a wandb nor a git URI and is not
                an existing local path.
        """
        if utils.is_bare_wandb_uri(uri):
            uri = api.settings("base_url") + uri
            _logger.info(f"Updating uri with base uri: {uri}")
        self.uri = uri
        self.api = api
        self.launch_spec = launch_spec
        self.target_entity = target_entity
        self.target_project = target_project
        self.name = name
        self.build_image: bool = docker_config.get("build_image", False)
        self.python_version: Optional[str] = docker_config.get(
            "python_version")
        self.cuda_version: Optional[str] = docker_config.get("cuda_version")
        self._base_image: Optional[str] = docker_config.get("base_image")
        self.docker_image: Optional[str] = docker_config.get("docker_image")
        # Container user id precedence: docker_config["user_id"] >
        # uid of the base image > per-resource default.
        uid = RESOURCE_UID_MAP.get(resource, 1000)
        if self._base_image:
            # NOTE(review): this lookup runs even when docker_config supplies
            # "user_id", whose value then wins below.
            uid = docker.get_image_uid(self._base_image)
            _logger.info(f"Retrieved base image uid {uid}")
        self.docker_user_id: int = docker_config.get("user_id", uid)
        self.git_version: Optional[str] = git_info.get("version")
        self.git_repo: Optional[str] = git_info.get("repo")
        self.override_args: Dict[str, Any] = overrides.get("args", {})
        self.override_config: Dict[str, Any] = overrides.get("run_config", {})
        self.resource = resource
        self.resource_args = resource_args
        self.deps_type: Optional[str] = None
        self.cuda = cuda
        self._runtime: Optional[str] = None
        self.run_id = generate_id()
        self._entry_points: Dict[str, EntryPoint] = {
        }  # todo: keep multiple entrypoint support?
        if "entry_point" in overrides:
            _logger.info("Adding override entry point")
            self.add_entry_point(overrides["entry_point"])
        # Classify the source; wandb and git sources get a fresh empty temp
        # dir (presumably filled later when the project is fetched), while a
        # local source uses the path itself.
        if utils._is_wandb_uri(self.uri):
            _logger.info(f"URI {self.uri} indicates a wandb uri")
            self.source = LaunchSource.WANDB
            self.project_dir = tempfile.mkdtemp()
        elif utils._is_git_uri(self.uri):
            _logger.info(f"URI {self.uri} indicates a git uri")
            self.source = LaunchSource.GIT
            self.project_dir = tempfile.mkdtemp()
        else:
            _logger.info(f"URI {self.uri} indicates a local uri")
            # assume local
            if not os.path.exists(self.uri):
                raise LaunchError(
                    "Assumed URI supplied is a local path but path is not valid"
                )
            self.source = LaunchSource.LOCAL
            self.project_dir = self.uri
        # Spec-level resource args take precedence over the constructor argument.
        if launch_spec.get("resource_args"):
            self.resource_args = launch_spec["resource_args"]

        # Auxiliary scratch directory, created for every source type.
        self.aux_dir = tempfile.mkdtemp()
        # Called last so it sees override_args / override_config set above.
        self.clear_parameter_run_config_collisions()
Ejemplo n.º 5
0
class LaunchAgent(object):
    """Launch agent that polls the given run queues and launches runs for wandb launch.

    On construction the agent registers itself with the server. ``loop`` then
    repeatedly pops jobs from its queues (while fewer than ``max_jobs`` are
    running), launches them through a backend, and reports its status.
    """

    def __init__(
        self,
        entity: str,
        project: str,
        queues: Optional[Iterable[str]] = None,
        max_jobs: Optional[int] = None,
    ):
        """Create the agent and register it server-side.

        Args:
            entity: Entity the agent polls and sends runs to.
            project: Project the agent polls and sends runs to.
            queues: Run queue names to poll; ``["default"]`` is used when this
                is None or empty. Any iterable is accepted.
            max_jobs: Maximum number of concurrently running jobs (default 1).
        """
        self._entity = entity
        self._project = project
        self._api = Api()
        self._settings = wandb.Settings()
        self._base_url = self._api.settings().get("base_url")
        self._jobs: Dict[Union[int, str], AbstractRun] = {}
        self._ticks = 0
        self._running = 0
        self._cwd = os.getcwd()
        self._namespace = wandb.util.generate_id()
        self._access = _convert_access("project")
        self._max_jobs = max_jobs or 1

        # Materialize the queues once: ``loop`` and ``print_status`` iterate
        # ``self._queues`` on every tick, which a one-shot iterable (e.g. a
        # generator) could only satisfy once.
        queue_list: Optional[List[str]] = (
            list(queues) if queues is not None else None
        )

        # serverside creation
        self.gorilla_supports_agents = (
            self._api.launch_agent_introspection() is not None
        )
        create_response = self._api.create_launch_agent(
            entity, project, queue_list, self.gorilla_supports_agents
        )
        self._id = create_response["launchAgentId"]
        # hacky: want to display this to the user but we don't get it back
        # from gql until polling starts. fix later
        self._name = ""
        self._queues: List[str] = queue_list if queue_list else ["default"]

    @property
    def job_ids(self) -> List[Union[int, str]]:
        """Returns a list of keys running job ids for the agent."""
        return list(self._jobs.keys())

    def pop_from_queue(self, queue: str) -> Any:
        """Pops an item off the run queue to run as a job.

        Returns the popped item, or None if the API call raised.
        """
        try:
            ups = self._api.pop_from_run_queue(
                queue, entity=self._entity, project=self._project, agent_id=self._id,
            )
        except Exception as e:
            # Best-effort: a failed pop must not crash the polling loop, but
            # report it through wandb's terminal logger (and the debug log,
            # with traceback) rather than a bare print to stdout.
            _logger.exception(f"Failed to pop from run queue {queue}")
            wandb.termerror(f"Exception: {e}")
            return None
        return ups

    def print_status(self) -> None:
        """Prints the current status of the agent."""
        wandb.termlog(
            "agent {} polling on project {}, queues {} for jobs".format(
                self._name, self._project, " ".join(self._queues)
            )
        )

    def update_status(self, status: str) -> None:
        """Report ``status`` to the server; print an error if it reports failure."""
        update_ret = self._api.update_launch_agent_status(
            self._id, status, self.gorilla_supports_agents
        )
        if not update_ret["success"]:
            wandb.termerror("Failed to update agent status to {}".format(status))

    def finish_job_id(self, job_id: Union[str, int]) -> None:
        """Removes the job from our list for now."""
        # TODO:  keep logs or something for the finished jobs
        del self._jobs[job_id]
        self._running -= 1
        # update status back to polling if no jobs are running
        if self._running == 0:
            self.update_status(AGENT_POLLING)

    def _update_finished(self, job_id: Union[int, str]) -> None:
        """Drop the job if its run state is "failed" or "finished"."""
        if self._jobs[job_id].get_status().state in ["failed", "finished"]:
            self.finish_job_id(job_id)

    def _validate_and_fix_spec_project_entity(
        self, launch_spec: Dict[str, Any]
    ) -> None:
        """Checks if launch spec target project/entity differs from agent. Forces these values to agent's if they are set."""
        if (
            launch_spec.get("project") is not None
            and launch_spec.get("project") != self._project
        ) or (
            launch_spec.get("entity") is not None
            and launch_spec.get("entity") != self._entity
        ):
            wandb.termwarn(
                f"Launch agents only support sending runs to their own project and entity. This run will be sent to {self._entity}/{self._project}"
            )
            launch_spec["entity"] = self._entity
            launch_spec["project"] = self._project

    def run_job(self, job: Dict[str, Any]) -> None:
        """Sets up project and runs the job."""
        # TODO: logger
        wandb.termlog(f"agent: got job {job}")
        _logger.info(f"Agent job: {job}")
        # update agent status
        self.update_status(AGENT_RUNNING)

        # parse job
        _logger.info("Parsing launch spec")
        launch_spec = job["runSpec"]
        if launch_spec.get("overrides") and isinstance(
            launch_spec["overrides"].get("args"), list
        ):
            launch_spec["overrides"]["args"] = util._user_args_to_dict(
                launch_spec["overrides"].get("args", [])
            )
        self._validate_and_fix_spec_project_entity(launch_spec)

        project = create_project_from_spec(launch_spec, self._api)
        _logger.info("Fetching and validating project...")
        project = fetch_and_validate_project(project, self._api)
        _logger.info("Fetching resource...")
        resource = launch_spec.get("resource") or "local"
        backend_config: Dict[str, Any] = {
            PROJECT_DOCKER_ARGS: {},
            PROJECT_SYNCHRONOUS: False,  # agent always runs async
        }
        if _is_wandb_local_uri(self._base_url):
            _logger.info(
                "Noted a local URI. Setting local network arguments for docker"
            )
            if sys.platform == "win32":
                backend_config[PROJECT_DOCKER_ARGS]["net"] = "host"
            else:
                backend_config[PROJECT_DOCKER_ARGS]["network"] = "host"
            if sys.platform == "linux" or sys.platform == "linux2":
                backend_config[PROJECT_DOCKER_ARGS][
                    "add-host"
                ] = "host.docker.internal:host-gateway"

        backend_config["runQueueItemId"] = job["runQueueItemId"]
        _logger.info("Loading backend")
        backend = load_backend(resource, self._api, backend_config)
        backend.verify()
        _logger.info("Backend loaded...")
        run = backend.run(project)
        if run:
            self._jobs[run.id] = run
            self._running += 1

    def loop(self) -> None:
        """Main loop function for agent."""
        wandb.termlog(
            "launch agent polling project {}/{} on queues: {}".format(
                self._entity, self._project, ",".join(self._queues)
            )
        )
        try:
            while True:
                self._ticks += 1
                if self._running < self._max_jobs:
                    # only check for new jobs if we're not at max
                    for queue in self._queues:
                        job = self.pop_from_queue(queue)
                        if job:
                            self.run_job(job)
                            break  # do a full housekeeping loop before popping more jobs

                agent_response = self._api.get_launch_agent(
                    self._id, self.gorilla_supports_agents
                )
                # hacky, but we don't return the name on create so this is first time
                self._name = agent_response["name"]
                if agent_response["stopPolling"]:
                    # shutdown process and all jobs if requested from ui
                    raise KeyboardInterrupt
                # job_ids is a snapshot list, so finishing (deleting) jobs
                # while iterating is safe.
                for job_id in self.job_ids:
                    self._update_finished(job_id)
                if self._ticks % 2 == 0:
                    if self._running == 0:
                        self.update_status(AGENT_POLLING)
                        self.print_status()
                    else:
                        self.update_status(AGENT_RUNNING)
                time.sleep(AGENT_POLLING_INTERVAL)

        except KeyboardInterrupt:
            # temp: for local, kill all jobs. we don't yet have good handling for different
            # types of runners in general
            for run in self._jobs.values():
                if isinstance(run, LocalSubmittedRun):
                    run.command_proc.kill()
            self.update_status(AGENT_KILLED)
            wandb.termlog("Shutting down, active jobs:")
            self.print_status()
Ejemplo n.º 6
0
def run(
    uri: str,
    api: Api,
    entry_point: Optional[str] = None,
    version: Optional[str] = None,
    parameters: Optional[Dict[str, Any]] = None,
    name: Optional[str] = None,
    resource: str = "local",
    resource_args: Optional[Dict[str, Any]] = None,
    project: Optional[str] = None,
    entity: Optional[str] = None,
    docker_image: Optional[str] = None,
    config: Optional[Dict[str, Any]] = None,
    synchronous: Optional[bool] = True,
    cuda: Optional[bool] = None,
) -> AbstractRun:
    """Run a W&B launch experiment. The project can be a wandb URI or a Git URI.

    Arguments:
    uri: URI of experiment to run. A wandb run uri or a Git repository URI.
    api: An instance of a wandb Api from wandb.apis.internal.
    entry_point: Entry point to run within the project. Defaults to using the entry point used
        in the original run for wandb URIs, or main.py for git repository URIs.
    version: For Git-based projects, either a commit hash or a branch name.
    parameters: Parameters (dictionary) for the entry point command. Defaults to using
        the parameters used to run the original run.
    name: Name under which to launch the run.
    resource: Execution backend for the run: W&B provides built-in support for "local" backend
    resource_args: Resource related arguments for launching runs onto a remote backend.
        Will be stored on the constructed launch config under ``resource_args``.
    project: Target project to send launched run to
    entity: Target entity to send launched run to
    docker_image: Docker image for the run; passed through to the launch as ``docker_image``.
    config: A dictionary containing the configuration for the run. May also contain
        resource specific arguments under the key "resource_args". A user-provided
        ``"docker"`` entry overrides the default docker args. The caller's dict is
        not modified; a shallow copy is passed on.
    synchronous: Whether to block while waiting for a run to complete. Defaults to True.
        Note that if ``synchronous`` is False and ``resource`` is "local", this
        method will return, but the current process will block when exiting until
        the local run completes. If the current process is interrupted, any
        asynchronous runs launched via this method will be terminated. If
        ``synchronous`` is True and the run fails, the current process will
        error out as well.
    cuda: Whether to build a CUDA-enabled docker image or not


    Example:
        import wandb
        project_uri = "https://github.com/wandb/examples"
        params = {"alpha": 0.5, "l1_ratio": 0.01}
        # Run W&B project and create a reproducible docker environment
        # on a local host
        api = wandb.apis.internal.Api()
        wandb.launch(project_uri, api, parameters=params)


    Returns:
        an instance of `wandb.launch.SubmittedRun` exposing information (e.g. run ID)
        about the launched run.

    Raises:
        `wandb.exceptions.ExecutionError` If a run launched in blocking mode
        is unsuccessful.
    """
    docker_args: Dict[str, Any] = {}

    # For a local wandb server, give the container host networking so it can
    # reach the server.
    if _is_wandb_local_uri(api.settings("base_url")):
        if sys.platform == "win32":
            docker_args["net"] = "host"
        else:
            docker_args["network"] = "host"
        if sys.platform == "linux" or sys.platform == "linux2":
            docker_args["add-host"] = "host.docker.internal:host-gateway"

    # Shallow-copy so injecting the merged "docker" entry below does not
    # mutate the caller's dict.
    config = {} if config is None else dict(config)

    if "docker" in config:
        docker_args.update(config["docker"])  # user-provided args override
    config["docker"] = docker_args

    submitted_run_obj = _run(
        uri=uri,
        name=name,
        project=project,
        entity=entity,
        docker_image=docker_image,
        entry_point=entry_point,
        version=version,
        parameters=parameters,
        resource=resource,
        resource_args=resource_args,
        launch_config=config,
        synchronous=synchronous,
        cuda=cuda,
        api=api,
    )

    return submitted_run_obj