Example #1
0
    def run_job(self, job):
        """Run the agent's function once for a single sweep job.

        Writes ``job.config`` to a per-run YAML file under
        ``wandb/sweep-<sweep id>/``, exports the run id, config path and
        sweep id through environment variables, resets the wandb setup and
        then calls ``self._function``.

        Args:
            job: object exposing ``run_id`` (str) and ``config`` (a mapping
                whose values are dicts with a ``"value"`` key).

        Returns:
            True when the function raised or was interrupted; otherwise
            falls off the end and returns None.
        """
        run_id = job.run_id

        sweep_dir_name = "sweep-" + self._sweep_id
        config_name = "config-" + run_id + ".yaml"
        config_file = os.path.join("wandb", sweep_dir_name, config_name)
        config_util.save_config_file_from_dict(config_file, job.config)

        # Hand the run parameters to wandb through the environment.
        for env_key, env_value in (
            (wandb.env.RUN_ID, run_id),
            (wandb.env.CONFIG_PATHS, config_file),
            (wandb.env.SWEEP_ID, self._sweep_id),
        ):
            os.environ[env_key] = env_value
        wandb.setup(_reset=True)

        banner = "wandb: Agent Starting Run: {} with config:\n".format(run_id)
        config_lines = [
            "\t{}: {}".format(name, spec["value"])
            for name, spec in job.config.items()
        ]
        print(banner + "\n".join(config_lines))

        try:
            self._function()
            # Only join when the function actually left a run behind.
            if wandb.run:
                wandb.join()
        except KeyboardInterrupt as interrupt:
            print("Keyboard interrupt", interrupt)
            return True
        except Exception as failure:
            print("Problem", failure)
            return True
Example #2
0
    def _run_job(self, job):
        """Run the agent's function for one sweep job and record failures.

        Exports the run id, the absolute config path and the sweep id via
        environment variables, writes ``job.config`` to that config path,
        resets the wandb setup, logs the config and calls
        ``self._function``.

        Args:
            job: object exposing ``run_id`` (str) and ``config`` (a mapping
                whose values are dicts with a ``"value"`` key).

        Raises:
            KeyboardInterrupt: re-raised unchanged so the caller can stop.

        Any other exception finishes the run with exit code 1. The run is
        then either dropped from ``self._stopped_runs`` (if listed there)
        or recorded in ``self._errored_runs`` keyed by run id.
        """
        # Bind run_id before the try block: the error handler below reads
        # it, and binding it inside the try would turn any early failure
        # into an UnboundLocalError that hides the real exception.
        run_id = job.run_id
        try:
            config_file = os.path.join("wandb", "sweep-" + self._sweep_id,
                                       "config-" + run_id + ".yaml")
            os.environ[wandb.env.RUN_ID] = run_id
            os.environ[wandb.env.CONFIG_PATHS] = os.path.join(
                os.environ[wandb.env.DIR], config_file)
            config_util.save_config_file_from_dict(
                os.environ[wandb.env.CONFIG_PATHS], job.config)
            os.environ[wandb.env.SWEEP_ID] = self._sweep_id
            wandb_sdk.wandb_setup._setup(_reset=True)

            wandb.termlog("Agent Starting Run: {} with config:".format(run_id))
            for k, v in job.config.items():
                wandb.termlog("\t{}: {}".format(k, v["value"]))

            self._function()
            wandb.finish()
        except KeyboardInterrupt:
            # Bare raise keeps the original traceback intact.
            raise
        except Exception as e:
            wandb.finish(exit_code=1)
            if run_id in self._stopped_runs:
                # Presumably a run that was asked to stop; its failure is
                # not recorded as an error.
                self._stopped_runs.remove(run_id)
            else:
                self._errored_runs[run_id] = e
Example #3
0
 def send_config(self, data):
     """Upsert the run with the updated config and write it to config.yaml.

     The config dict is decoded from ``data.config.update``, sent through
     ``self._api.upsert_run`` under the current run id, and then saved as
     ``config.yaml`` inside ``self._settings.files_dir``.
     """
     config_dict = config_util.dict_from_proto_list(data.config.update)
     self._api.upsert_run(
         name=self._run.run_id,
         config=config_dict,
         **self._api_settings
     )
     config_util.save_config_file_from_dict(
         os.path.join(self._settings.files_dir, "config.yaml"),
         config_dict,
     )
Example #4
0
    def send_run(self, data):
        """Handle a run message: persist config, set up resume, init the run.

        The first run message (``self._run`` still None) is treated as
        coming from ``wandb.init``: the project name is defaulted if empty,
        resume status is checked, and the run threads are started at the
        end. Later messages only update the run.

        If the resume check yields an error, it is sent back on the result
        queue when a response was requested, otherwise logged, and the
        method returns without initialising the run.
        """
        run = data.run
        first_message = self._run is None
        cfg_file = os.path.join(self._settings.files_dir,
                                filenames.CONFIG_FNAME)

        # build config dict
        cfg = None
        if run.config:
            cfg = config_util.dict_from_proto_list(run.config.update)
            config_util.save_config_file_from_dict(cfg_file, cfg)

        resume_error = None
        if first_message:
            # Ensure we have a project to query for status
            if run.project == "":
                run.project = util.auto_project_name(self._settings.program)
            # Only check resume status on `wandb.init`
            resume_error = self._maybe_setup_resume(run)

        if resume_error is not None:
            if not data.control.req_resp:
                logger.error("Got error in async mode: %s",
                             resume_error.message)
                return
            failure = wandb_internal_pb2.Result(uuid=data.uuid)
            failure.run_result.run.CopyFrom(run)
            failure.run_result.error.CopyFrom(resume_error)
            self._result_q.put(failure)
            return

        # Save the resumed config
        resumed_cfg = self._resume_state["config"]
        if resumed_cfg is not None:
            # TODO: should we merge this with resumed config?
            # The resumed dict is updated in place, so values from this
            # message win over the resumed ones.
            resumed_cfg.update(cfg or {})
            cfg = resumed_cfg
            config_util.save_config_file_from_dict(cfg_file, cfg)

        self._init_run(run, cfg)

        if data.control.req_resp:
            reply = wandb_internal_pb2.Result(uuid=data.uuid)
            # TODO: we could do self._interface.publish_defer(resp) to notify
            # the handler not to actually perform server updates for this uuid
            # because the user process will send a summary update when we resume
            reply.run_result.run.CopyFrom(self._run)
            self._result_q.put(reply)

        # Only spin up our threads on the first run message
        if not first_message:
            logger.info("updated run: %s", self._run.run_id)
            return
        self._start_run_threads()
    def _command_run(self, command):
        """Start a process for a sweep "run" command.

        Writes the run's config to ``wandb/sweep-<id>/config-<run>.yaml``,
        exports the run id and config path via the environment, and then
        starts either a function-backed ``AgentProcess`` (when
        ``self._function`` is set) or a command-line one built by expanding
        the ``${...}`` placeholders of the sweep command.

        Args:
            command: dict with ``"args"`` (mapping of parameter name to a
                dict holding ``"value"``), ``"run_id"`` and, for
                command-line runs, ``"program"``.
        """
        run_args = command["args"]
        config_summary = "\n".join(
            "\t{}: {}".format(name, spec["value"])
            for name, spec in run_args.items()
        )
        logger.info("Agent starting run with config:\n" + config_summary)
        if self._in_jupyter:
            print("wandb: Agent Starting Run: {} with config:\n".format(
                command.get("run_id")) + config_summary)

        # setup default sweep command if not configured
        sweep_command = self._sweep_command
        if not sweep_command:
            sweep_command = [
                "${env}",
                "${interpreter}",
                "${program}",
                "${args}",
            ]

        run_id = command.get("run_id")
        sweep_id = os.environ.get(wandb.env.SWEEP_ID)
        # TODO(jhr): move into settings
        sweep_dir_name = "sweep-" + sweep_id
        config_stem = "config-" + run_id
        config_file = os.path.join("wandb", sweep_dir_name,
                                   config_stem + ".yaml")
        json_file = os.path.join("wandb", sweep_dir_name,
                                 config_stem + ".json")
        config_util.save_config_file_from_dict(config_file, run_args)
        os.environ[wandb.env.RUN_ID] = run_id
        os.environ[wandb.env.CONFIG_PATHS] = config_file

        # Snapshot the environment after the run variables were exported.
        env = dict(os.environ)

        flag_values = {name: spec["value"] for name, spec in run_args.items()}
        flags_no_hyphens = [
            "{}={}".format(name, value) for name, value in flag_values.items()
        ]
        flags = ["--" + flag for flag in flags_no_hyphens]
        flags_json = json.dumps(flag_values)

        # The JSON file is only written when the sweep command asks for it.
        if "${args_json_file}" in sweep_command:
            with open(json_file, "w") as fp:
                fp.write(flags_json)

        if self._function:
            proc = AgentProcess(
                function=self._function,
                env=env,
                run_id=run_id,
                in_jupyter=self._in_jupyter,
            )
        else:
            sweep_vars = {
                "interpreter": ["python"],
                "program": [command["program"]],
                "args": flags,
                "args_no_hyphens": flags_no_hyphens,
                "args_json": [flags_json],
                "args_json_file": [json_file],
                "env": ["/usr/bin/env"],
            }
            if platform.system() == "Windows":
                sweep_vars.pop("env")

            command_list = []
            for part in map(str, sweep_command):
                if part.startswith("${") and part.endswith("}"):
                    # Unknown placeholders expand to nothing.
                    command_list.extend(sweep_vars.get(part[2:-1]) or [])
                else:
                    command_list.append(part)
            logger.info("About to run command: {}".format(" ".join(
                '"%s"' % c if " " in c else c for c in command_list)))
            proc = AgentProcess(command=command_list, env=env)

        self._run_processes[run_id] = proc
        # we keep track of when we sent the sigterm to give processes a chance
        # to handle the signal before sending sigkill every heartbeat
        proc.last_sigterm_time = None
        self._last_report_time = None