コード例 #1
0
ファイル: init.py プロジェクト: microsoft/openpai-runtime
def main():
    """Inject user-supplied shell commands into the task's pre/post scripts.

    ``parameters`` from the plugin config may use one of two mutually
    exclusive formats:

    * ``callbacks``: a list of ``{"event": ..., "commands": [...]}`` items.
      Commands of ``taskStarts`` events go into the pre script, commands of
      ``taskSucceeds`` events go into the post script, and any other event
      is ignored.
    * ``preCommands`` / ``postCommands``: command lists injected directly
      into the pre and post script respectively.

    Nothing is injected when ``parameters`` is missing or empty.

    Raises:
        ValueError: if ``callbacks`` is combined with ``preCommands`` or
            ``postCommands``.
    """
    [plugin_config, pre_script, post_script] = plugin_init()

    plugin_helper = PluginHelper(plugin_config)
    parameters = plugin_config.get("parameters")
    if not parameters:
        return

    if "callbacks" in parameters:
        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would silently accept a conflicting config.
        conflicting = [
            key for key in ("preCommands", "postCommands") if key in parameters
        ]
        if conflicting:
            raise ValueError("'callbacks' cannot be combined with {}".format(
                ", ".join(conflicting)))

        pre_commands = []
        post_commands = []
        for callback in parameters["callbacks"]:
            if callback["event"] == "taskStarts":
                pre_commands.extend(callback["commands"])
            elif callback["event"] == "taskSucceeds":
                post_commands.extend(callback["commands"])
        if pre_commands:
            plugin_helper.inject_commands(pre_commands, pre_script)
        if post_commands:
            plugin_helper.inject_commands(post_commands, post_script)
    else:
        if "preCommands" in parameters:
            plugin_helper.inject_commands(parameters["preCommands"],
                                          pre_script)
        if "postCommands" in parameters:
            plugin_helper.inject_commands(parameters["postCommands"],
                                          post_script)
コード例 #2
0
def main():
    """Inject commands that set up sshd and optionally wait on an ssh barrier.

    Builds the argument list for ``sshd.sh`` (located next to this file):

    * the lower-cased ``jobssh`` flag, which defaults to ``"false"`` and is
      forced to ``"false"`` when the ``GANG_ALLOCATION`` environment
      variable is ``"false"``;
    * followed by the ``userssh`` type and value when both are present.

    When only a ``"false"`` job-ssh flag remains, the sshd step is skipped.
    Otherwise the pre script gets the command returned by
    ``try_to_install_by_cache("ssh", ...)`` (apt-get fallbacks) and then the
    ``sshd.sh`` call. If job ssh and ``sshbarrier`` are both ``"true"``,
    ``sshbarrier.sh`` is also called, with ``sshbarrierTimeout`` as its
    argument when set.

    Every argument is shell-quoted because the commands end up in the job's
    pre script and the values come from the user's job config.
    """
    import shlex  # local import: the visible file has no import block to extend

    LOGGER.info("Preparing ssh runtime plugin commands")
    [plugin_config, pre_script, _] = plugin_init()
    plugin_helper = PluginHelper(plugin_config)
    parameters = plugin_config.get("parameters")

    if not parameters:
        LOGGER.info("Ssh plugin parameters is empty, ignore this")
        return

    gang_allocation = os.environ.get("GANG_ALLOCATION", "true")
    if gang_allocation == "false":
        LOGGER.warning(
            "Job ssh is conflict with gang allocation, set job ssh to false")
        jobssh = "false"
    elif "jobssh" in parameters:
        jobssh = str(parameters["jobssh"]).lower()
    else:
        jobssh = "false"
    cmd_params = [jobssh]

    # `or {}` so an explicit `userssh: null` is treated as "not configured"
    # instead of raising TypeError on the membership test.
    userssh = parameters.get("userssh") or {}
    if "type" in userssh and "value" in userssh:
        cmd_params.append(str(userssh["type"]))
        cmd_params.append(str(userssh["value"]))

    script_dir = os.path.dirname(os.path.abspath(__file__))

    # write call to real executable script
    command = []
    if cmd_params == ["false"]:
        LOGGER.info("Skip sshd script since neither jobssh or userssh is set")
    else:
        command = [
            try_to_install_by_cache(
                "ssh",
                fallback_cmds=[
                    "apt-get update",
                    "apt-get install -y openssh-client openssh-server",
                ]),
            # shlex.quote keeps values containing quotes, spaces or shell
            # metacharacters as a single, literal argument to sshd.sh.
            "{}/sshd.sh {}\n".format(
                script_dir, " ".join(shlex.quote(p) for p in cmd_params)),
        ]

    # ssh barrier
    if jobssh == "true" and str(
            parameters.get("sshbarrier")).lower() == "true":
        if "sshbarrierTimeout" in parameters:
            barrier_params = shlex.quote(str(parameters["sshbarrierTimeout"]))
        else:
            barrier_params = ""
        command.append("{}/sshbarrier.sh {}\n".format(script_dir,
                                                      barrier_params))

    plugin_helper.inject_commands(command, pre_script)
    LOGGER.info("Ssh runtime plugin perpared")
コード例 #3
0
ファイル: init.py プロジェクト: sycomix/pai
def main():
    """Render and schedule a TensorBoard launcher for the first task only.

    Only instance 0 of the first task role does any work. It renders
    ``tensorboard.sh`` next to this file from ``tensorboard.sh.template``
    via ``generate_tensorboard_commands``, then injects two pre-script
    commands: one that makes the launcher executable and one that runs it.
    Every other task, or a config without ``parameters``, only logs and
    returns.
    """
    LOGGER.info("Preparing tensorboard runtime plugin commands")

    [plugin_config, pre_script, _] = plugin_init()
    parameters = plugin_config.get("parameters")

    is_first_instance_of_first_role = (TASK_ROLE_LIST[0] == TASK_ROLE_NAME
                                       and TASK_ROLE_INDEX == 0)
    if not is_first_instance_of_first_role:
        LOGGER.info(
            "Not first taskrole or not first task instance, ignore this plugin"
        )
        return
    if not parameters:
        LOGGER.info("Tensorboard plugin parameters is empty, ignore this")
        return

    plugin_dir = os.path.dirname(os.path.abspath(__file__))
    launcher = "{}/tensorboard.sh".format(plugin_dir)
    template = "{}/tensorboard.sh.template".format(plugin_dir)

    # The launcher file is opened (and truncated) before rendering starts.
    with open(launcher, "w+") as launcher_file:
        launcher_file.write(generate_tensorboard_commands(template, parameters))

    launch_steps = ["chmod u+x {}".format(launcher), launcher]
    PluginHelper(plugin_config).inject_commands(launch_steps, pre_script)
    LOGGER.info("Tensorboard runtime plugin perpared")
コード例 #4
0
def main():
    """Entry point of the deprecated teamwise storage plugin.

    The plugin is deprecated: this only logs a warning and returns without
    touching the job scripts. The former implementation is preserved,
    uncalled, in ``_prepare_storage_commands`` because it may be reused to
    support user-defined storage.
    """
    LOGGER.warning("This plugin is deprecated, will ignore this plugin")


def _prepare_storage_commands():
    """Former implementation: inject storage setup commands into the pre script.

    Intentionally not called; kept for possible reuse to support
    user-defined storage. If the ``StorageCommandGenerator`` cannot be
    constructed, the error is logged with its traceback and the process
    exits with status 1.
    """
    LOGGER.info("Preparing storage runtime plugin commands")
    [plugin_config, pre_script, _] = plugin_init()
    parameters = plugin_config.get("parameters", "")

    try:
        command_generator = StorageCommandGenerator()
    except Exception:  #pylint: disable=broad-except
        LOGGER.exception("Failed to generate storage commands")
        sys.exit(1)
    pre_script_commands = command_generator.generate_plugin_commands(
        parameters)

    PluginHelper(plugin_config).inject_commands(pre_script_commands,
                                                pre_script)
    LOGGER.info("Storage runtime plugin perpared")
コード例 #5
0
ファイル: init.py プロジェクト: shuangjiexu/pai
def main():
    """Inject commands that start sshd and optionally wait on an ssh barrier.

    ``sshd.sh`` (next to this file) receives the following arguments:

    * the lower-cased ``jobssh`` flag, or ``"false"`` when unset;
    * followed by the ``userssh`` type and value when both are present.

    If job ssh and ``sshbarrier`` are both ``"true"``, ``sshbarrier.sh`` is
    also called, with each entry of ``sshbarriertaskroles`` as an argument.

    Every argument is shell-quoted because the commands end up in the job's
    pre script and the values come from the user's job config.
    """
    import shlex  # local import: the visible file has no import block to extend

    LOGGER.info("Preparing ssh runtime plugin commands")
    [plugin_config, pre_script, _] = plugin_init()
    plugin_helper = PluginHelper(plugin_config)
    parameters = plugin_config.get("parameters")

    if not parameters:
        LOGGER.info("Ssh plugin parameters is empty, ignore this")
        return

    if "jobssh" in parameters:
        jobssh = str(parameters["jobssh"]).lower()
    else:
        jobssh = "false"
    cmd_params = [jobssh]

    # `or {}` so an explicit `userssh: null` is treated as "not configured".
    userssh = parameters.get("userssh") or {}
    if "type" in userssh and "value" in userssh:
        cmd_params.append(str(userssh["type"]))
        cmd_params.append(str(userssh["value"]))

    script_dir = os.path.dirname(os.path.abspath(__file__))

    # write call to real executable script; shlex.quote keeps values with
    # quotes, spaces or shell metacharacters as one literal argument
    command = [
        "{}/sshd.sh {}\n".format(
            script_dir, " ".join(shlex.quote(p) for p in cmd_params))
    ]

    # ssh barrier
    if jobssh == "true" and str(
            parameters.get("sshbarrier")).lower() == "true":
        # Previously wrapped in "...", which still lets the shell expand
        # `$` and backticks inside task-role names.
        task_roles = parameters.get("sshbarriertaskroles") or []
        barrier_params = " ".join(
            shlex.quote(str(task_role)) for task_role in task_roles)
        command.append("{}/sshbarrier.sh {}\n".format(script_dir,
                                                      barrier_params))

    plugin_helper.inject_commands(command, pre_script)
    LOGGER.info("Ssh runtime plugin perpared")
コード例 #6
0
def main():
    """Clone the job's git repository and optionally move it into place.

    Clones ``parameters["repo_uri"]`` into ``../../code`` relative to this
    file. When ``parameters["options"]`` is present, it is passed to
    ``Repo.clone_from`` as ``multi_options``. If ``clone_dir`` is set,
    pre-script commands are injected that run ``check_clone_dir.sh`` on it,
    create it, and move the cloned files into it.

    Logs an error and exits with status 1 when no ``repo_uri`` is
    configured.
    """
    import shlex  # local import: the visible file has no import block to extend

    LOGGER.info("Preparing git runtime plugin")
    [plugin_config, pre_script, _] = plugin_init()
    plugin_helper = PluginHelper(plugin_config)

    cur_dir = os.path.dirname(os.path.abspath(__file__))
    repo_local_path = os.path.join(cur_dir, "../../code")
    parameters = plugin_config.get("parameters")
    if not parameters or "repo_uri" not in parameters:
        LOGGER.error("Can not find repo in runtime plugin")
        sys.exit(1)
    if "options" in parameters:
        Repo.clone_from(parameters["repo_uri"],
                        repo_local_path,
                        multi_options=parameters["options"])
    else:
        Repo.clone_from(parameters["repo_uri"], repo_local_path)
    if "clone_dir" in parameters:
        # clone_dir comes from the job config and is written into a shell
        # script; quote it so spaces or metacharacters cannot split the
        # path or inject extra commands.
        clone_dir = shlex.quote(str(parameters["clone_dir"]))
        plugin_helper.inject_commands([
            "{}/check_clone_dir.sh {}".format(cur_dir, clone_dir),
            "mkdir -p {}".format(clone_dir),
            # The `/*` glob stays outside the quotes so the shell expands it.
            # NOTE(review): `*` does not match dotfiles such as `.git`, so
            # hidden entries are left behind; confirm this is intended.
            "mv -f {}/* {}".format(shlex.quote(repo_local_path), clone_dir),
        ], pre_script)
コード例 #7
0
ファイル: init.py プロジェクト: sycomix/pai
def main():
    """Inject storage setup commands into the job's pre script.

    Constructs a ``StorageCommandGenerator``. If construction fails, the
    error is logged with its traceback and the process exits with status 1.
    Otherwise the commands the generator produces for the plugin's
    ``parameters`` (``""`` when absent) are injected into the pre script.
    """
    LOGGER.info("Preparing storage runtime plugin commands")
    plugin_config, pre_script, _ = plugin_init()
    storage_parameters = plugin_config.get("parameters", "")

    try:
        generator = StorageCommandGenerator()
    except Exception:  #pylint: disable=broad-except
        LOGGER.exception("Failed to generate storage commands")
        sys.exit(1)

    storage_commands = generator.generate_plugin_commands(storage_parameters)
    helper = PluginHelper(plugin_config)
    helper.inject_commands(storage_commands, pre_script)
    LOGGER.info("Storage runtime plugin perpared")
コード例 #8
0
def main():
    """Inject a background TensorBoard launch into the job's pre script.

    ``parameters["logdir"]`` is read as a mapping whose items are rendered
    as ``key:value`` pairs joined by commas. ``parameters["port"]`` is
    passed as ``--port``. The resulting command runs TensorBoard in the
    background (``&``). Nothing is injected when ``parameters`` is missing
    or empty.
    """
    import shlex  # local import: the visible file has no import block to extend

    LOGGER.info("Preparing tensorboard runtime plugin commands")
    [plugin_config, pre_script, _] = plugin_init()
    parameters = plugin_config.get("parameters")

    if not parameters:
        LOGGER.info("Tensorboard plugin parameters is empty, ignore this")
        return

    logdir = ",".join(
        "{}:{}".format(name, path)
        for name, path in parameters["logdir"].items())
    # Both values come from the job config and are spliced into a shell
    # command; quote them so log paths with spaces or metacharacters stay a
    # single argument instead of splitting or being interpreted.
    commands = [
        "tensorboard --logdir={} --port={} &\n".format(
            shlex.quote(logdir), shlex.quote(str(parameters["port"])))
    ]

    PluginHelper(plugin_config).inject_commands(commands, pre_script)
    LOGGER.info("Tensorboard runtime plugin perpared")
コード例 #9
0
def main():
    """Inject user-provided commands into the job's pre and post scripts.

    ``parameters["preCommands"]`` goes into the pre script and
    ``parameters["postCommands"]`` into the post script. Absent keys, and a
    missing or empty ``parameters``, are simply skipped.
    """
    [plugin_config, pre_script, post_script] = plugin_init()

    helper = PluginHelper(plugin_config)
    parameters = plugin_config.get("parameters")
    if not parameters:
        return

    # Pre commands are handled before post commands, as before.
    targets = (("preCommands", pre_script), ("postCommands", post_script))
    for key, script in targets:
        if key in parameters:
            helper.inject_commands(parameters[key], script)
コード例 #10
0
ファイル: init.py プロジェクト: MShaffar19/openpai-runtime
def main():
    """Inject commands that set up sshd and optionally wait on an ssh barrier.

    Builds the argument list for ``sshd.sh`` (located next to this file):

    * the lower-cased ``jobssh`` flag, which defaults to ``"false"`` and is
      forced to ``"false"`` when the ``GANG_ALLOCATION`` environment
      variable is ``"false"``;
    * when ``userssh`` is configured with a ``type`` and at least one key is
      available, that type plus all keys joined by newlines.

    The keys are fetched with ``get_user_public_keys`` when an
    ``application_token`` is in the plugin config; a failure there logs the
    error and exits with status 1. A non-empty ``userssh`` ``value`` is
    added to them.

    If job ssh and ``sshbarrier`` are both ``"true"``, ``sshbarrier.sh`` is
    also called, with ``sshbarrierTimeout`` as its argument when set.

    Every argument is shell-quoted because the commands end up in the job's
    pre script and the values come from the user's job config.
    """
    import shlex  # local import: the visible file has no import block to extend

    LOGGER.info("Preparing ssh runtime plugin commands")
    [plugin_config, pre_script, _] = plugin_init()
    plugin_helper = PluginHelper(plugin_config)
    parameters = plugin_config.get("parameters")

    if not parameters:
        LOGGER.info("Ssh plugin parameters is empty, ignore this")
        return

    gang_allocation = os.environ.get("GANG_ALLOCATION", "true")
    if gang_allocation == "false":
        LOGGER.warning(
            "Job ssh is conflict with gang allocation, set job ssh to false")
        jobssh = "false"
    elif "jobssh" in parameters:
        jobssh = str(parameters["jobssh"]).lower()
    else:
        jobssh = "false"
    cmd_params = [jobssh]

    if "userssh" in parameters:
        # `or {}` so an explicit `userssh: null` does not raise TypeError.
        userssh = parameters["userssh"] or {}

        # get user public keys from rest server
        application_token = plugin_config.get("application_token")
        # NOTE(review): PAI_USER_NAME may be unset (None); confirm that
        # get_user_public_keys tolerates that.
        username = os.environ.get("PAI_USER_NAME")
        public_keys = []
        if application_token:
            try:
                public_keys = get_user_public_keys(application_token, username)
            except Exception:  #pylint: disable=broad-except
                LOGGER.error("Failed to get user public keys", exc_info=True)
                sys.exit(1)

        # Truthiness also skips a `null` value, which would otherwise make
        # the newline join below fail.
        if userssh.get("value"):
            public_keys.append(userssh["value"])

        # append user public keys to cmd_params
        if "type" in userssh and public_keys:
            cmd_params.append(str(userssh["type"]))
            cmd_params.append("\n".join(public_keys))

    script_dir = os.path.dirname(os.path.abspath(__file__))

    # write call to real executable script
    command = []
    if cmd_params == ["false"]:
        LOGGER.info("Skip sshd script since neither jobssh or userssh is set")
    else:
        command = [
            try_to_install_by_cache(
                "ssh",
                fallback_cmds=[
                    "apt-get update",
                    "apt-get install -y openssh-client openssh-server",
                ]),
            # shlex.quote keeps key material (which may contain quotes,
            # spaces or newlines) as a single literal argument to sshd.sh.
            "{}/sshd.sh {}\n".format(
                script_dir, " ".join(shlex.quote(p) for p in cmd_params)),
        ]

    # ssh barrier
    if jobssh == "true" and str(
            parameters.get("sshbarrier")).lower() == "true":
        if "sshbarrierTimeout" in parameters:
            barrier_params = shlex.quote(str(parameters["sshbarrierTimeout"]))
        else:
            barrier_params = ""
        command.append("{}/sshbarrier.sh {}\n".format(script_dir,
                                                      barrier_params))

    plugin_helper.inject_commands(command, pre_script)
    LOGGER.info("Ssh runtime plugin perpared")