Example #1
def get_num_gpu_used():
    if not remoterun.get_env_var(
            "CUDA_VISIBLE_DEVICES",
            False) or remoterun.get_env_var("CUDA_VISIBLE_DEVICES") == "-1":
        return 0
    else:
        return len(remoterun.get_env_var("CUDA_VISIBLE_DEVICES").split(','))
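A quick usage sketch (assuming, as these examples suggest, that remoterun.get_env_var falls back to the process environment when run locally):

# Hypothetical usage sketch of get_num_gpu_used(); assumes remoterun.get_env_var
# reads os.environ when no override is in place.
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"   # GPUs explicitly disabled
assert get_num_gpu_used() == 0
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # two visible devices
assert get_num_gpu_used() == 2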
Example #2
def get_location_data():
    global _cached_location_data, _cached_session
    if _cached_location_data is not None:
        create_session_if_needed()
        return _cached_location_data

    api_ticket = remoterun.get_env_var("DKU_API_TICKET", d=None)

    if api_ticket is not None:
        # We have an API ticket so we are in DSS
        _cached_location_data = {
            "auth_mode": "TICKET",
            "api_ticket": api_ticket
        }

        _cached_location_data["backend_url"] = "http://%s:%s" % \
                            (remoterun.get_env_var("DKU_BACKEND_HOST", "127.0.0.1"),
                                remoterun.get_env_var("DKU_BACKEND_PORT"))

        if os.getenv("DKU_SERVER_KIND", "BACKEND") == "BACKEND":
            _cached_location_data["has_a_jek"] = False
        else:
            _cached_location_data["has_a_jek"] = True
            _cached_location_data["jek_url"] = "http://%s:%s" % (os.getenv(
                "DKU_SERVER_HOST",
                "127.0.0.1"), int(os.getenv("DKU_SERVER_PORT")))

    else:
        # No API ticket, so we are running outside of DSS. Look for remote DSS
        # authentication info, in this order:
        #   - dataiku.set_remote_dss (has been handled at the top of this method)
        #   - Environment variables DKU_DSS_URL and DKU_API_KEY
        #   - ~/.dataiku/config.json

        if os.getenv("DKU_DSS_URL") is not None:
            set_remote_dss(os.environ["DKU_DSS_URL"],
                           os.environ["DKU_API_KEY"])
        else:
            config_file = osp.expanduser("~/.dataiku/config.json")
            if osp.isfile(config_file):
                with open(config_file) as f:
                    config = json.load(f)

                instance_details = config["dss_instances"][
                    config["default_instance"]]

                set_remote_dss(instance_details["url"],
                               instance_details["api_key"],
                               no_check_certificate=instance_details.get(
                                   "no_check_certificate", False))
            else:
                raise Exception(
                    "No DSS URL or API key found from any location")

    create_session_if_needed()

    return _cached_location_data
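For the config.json fallback branch, here is a minimal sketch of writing a file with the keys get_location_data() actually reads (the url and api_key values are placeholders):

# Hypothetical sketch: a minimal ~/.dataiku/config.json with the keys read above.
import json
import os.path as osp

config = {
    "default_instance": "default",
    "dss_instances": {
        "default": {
            "url": "https://dss.example.com:11200",  # placeholder URL
            "api_key": "YOUR_API_KEY",               # placeholder key
            "no_check_certificate": False            # optional; defaults to False
        }
    }
}
with open(osp.expanduser("~/.dataiku/config.json"), "w") as f:
    json.dump(config, f, indent=2)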
Example #3
    def __init__(self):
        self.project_key = remoterun.get_env_var('DKU_CURRENT_PROJECT_KEY')
        if remoterun.has_env_var('DKU_CURRENT_SCENARIO_TRIGGER_FILE'):
            trigger_json_file = remoterun.get_env_var(
                'DKU_CURRENT_SCENARIO_TRIGGER_FILE')
            with open(trigger_json_file, 'r') as f:
                self.scenario_trigger = json.load(f)
        else:
            self.scenario_trigger = None
Example #4
def _get_variable_value(variable, variable_name, os_variable_name):
    if variable is None:

        if not remoterun.get_env_var(os_variable_name, False):
            raise ValueError(
                "You must provide an '{}' argument".format(variable_name))
        else:
            return remoterun.get_env_var(os_variable_name)

    return variable
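A sketch of the fallback order (the variable names below are illustrative, and again assume remoterun.get_env_var reads the process environment):

# Hypothetical usage sketch of _get_variable_value(): an explicit argument wins,
# then the environment variable, otherwise a ValueError is raised.
import os

os.environ["DKU_CURRENT_ANALYSIS_ID"] = "a1b2c3"  # illustrative value
_get_variable_value("explicit", "analysis_id", "DKU_CURRENT_ANALYSIS_ID")  # -> "explicit"
_get_variable_value(None, "analysis_id", "DKU_CURRENT_ANALYSIS_ID")        # -> "a1b2c3"

del os.environ["DKU_CURRENT_ANALYSIS_ID"]
_get_variable_value(None, "analysis_id", "DKU_CURRENT_ANALYSIS_ID")
# ValueError: You must provide an 'analysis_id' argument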
Example #5
def default_project_key():
    if remoterun.has_env_var("DKU_CURRENT_PROJECT_KEY"):
        return remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY")
    else:
        raise Exception(
            "Default project key is not specified (no DKU_CURRENT_PROJECT_KEY in env)"
        )
Example #6
    def get_state(self):
        logging.info("poll state")
        remote_kernel = backend_json_call(
            "jupyter/poll-remote-kernel",
            data={
                "contextProjectKey": remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY"),
                "batchId": self.batch_id
            })
        logging.info("Polled, got: %s" % json.dumps(remote_kernel))
        return remote_kernel.get("state", None)
Example #7
def get_keras_model_location_from_trained_model(session_id=None,
                                                analysis_id=None,
                                                mltask_id=None):
    analysis_id = _get_variable_value(analysis_id, "analysis_id",
                                      constants.DKU_CURRENT_ANALYSIS_ID)
    mltask_id = _get_variable_value(mltask_id, "mltask_id",
                                    constants.DKU_CURRENT_MLTASK_ID)

    # Retrieve info on location of model
    project_key = remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY")
    mltask = dataiku.api_client().get_project(project_key).get_ml_task(
        analysis_id, mltask_id)
    mltask_status = mltask.get_status()

    # Check good backend
    if mltask_status["headSessionTask"]["backendType"] != "KERAS":
        raise ValueError("The mltask you are accessing was not a Keras model")

    # We assume here that there is only one model per session, i.e. session_id are unique
    # in mltask_status["fullModelIds"], which is the case for KERAS backend
    sessions = [
        p["fullModelId"]["sessionId"] for p in mltask_status["fullModelIds"]
    ]
    if session_id is None:
        last_session = sorted(
            [int(sess_id_str[1:]) for sess_id_str in sessions])[-1]
        session_id = "s{}".format(last_session)
    try:
        session_index = sessions.index(session_id)
    except ValueError:
        raise ValueError(
            "The 'session_id' you are providing cannot be found in the mltask. "
            "Available session_ids are: {}".format(sessions))

    session = mltask_status["fullModelIds"][session_index]["fullModelId"]

    dip_home = dataiku.core.base.get_dip_home()
    model_folder = os.path.join(dip_home, "analysis-data", project_key,
                                analysis_id, mltask_id, "sessions",
                                session["sessionId"],
                                session["preprocessingId"], session["modelId"])

    model_location = os.path.join(model_folder, constants.KERAS_MODEL_FILENAME)

    if not os.path.isfile(model_location):
        raise ValueError(
            "No model found for this mltask. Did it run without errors?")

    return model_location
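One detail worth a worked example: the session ids are sorted numerically after stripping the "s" prefix, because a plain string sort would rank "s9" after "s10":

# Worked example of the latest-session selection above: the numeric sort is
# needed because max(sessions) would pick "s9" over "s10" lexicographically.
sessions = ["s1", "s9", "s10"]
last_session = sorted([int(s[1:]) for s in sessions])[-1]  # 10
session_id = "s{}".format(last_session)                    # "s10"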
Example #8
def get_auth_headers():
    location_data = get_location_data()

    if location_data["auth_mode"] == "TICKET":
        headers = {"X-DKU-APITicket": location_data["api_ticket"]}
    else:
        auth = requests.auth.HTTPBasicAuth(location_data["api_key"], "")
        fake_req = requests.Request()
        auth(fake_req)
        headers = fake_req.headers

    if remoterun.has_env_var("DKU_CALL_ORIGIN"):
        headers['X-DKU-CallOrigin'] = remoterun.get_env_var("DKU_CALL_ORIGIN")

    return headers
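A hedged usage sketch: the returned headers can be attached to a raw HTTP call against DSS, assuming a TICKET-mode location_data as built in Example #2 (the endpoint path below is a placeholder, not a documented route):

# Hypothetical usage sketch: authenticate a raw HTTP call with the headers.
import requests

location_data = get_location_data()
resp = requests.get(
    location_data["backend_url"] + "/some/endpoint",  # placeholder path
    headers=get_auth_headers())
resp.raise_for_status()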
Example #9
def get_keras_model_location_from_saved_model(saved_model_id):
    project_key = remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY")
    active_model_version = dataiku.api_client().get_project(project_key)\
                                               .get_saved_model(saved_model_id)\
                                               .get_active_version()

    dip_home = dataiku.core.base.get_dip_home()
    model_folder = os.path.join(dip_home, "saved_models", project_key,
                                saved_model_id, "versions",
                                active_model_version["id"])
    model_location = os.path.join(model_folder, constants.KERAS_MODEL_FILENAME)

    if not os.path.isfile(model_location):
        print(model_location)
        raise ValueError("No model found for this saved model.")

    return model_location
Example #10
def get_dip_home():
    return remoterun.get_env_var('DIP_HOME')
Example #11
    def initialize(self):
        with open(self.connection_file, 'r') as f:
            local_connection_file = json.loads(f.read())

        # start the forwarding (zmq-wise), i.e. relaying the sockets in the connection file
        port_pairs = []
        for port_type in [
                'shell_port', 'iopub_port', 'stdin_port', 'control_port',
                'hb_port'
        ]:
            local_port = local_connection_file.get(port_type, None)
            if local_port is None or local_port == 0:
                continue
            remote_port = None  # means bind to random
            port_pairs.append([local_port, remote_port, port_type[:-5]])

        def printout(m):
            logging.info(m)

        # bind on 127.0.0.1 for the jupyter-server-facing side and on all interfaces for the kernel-facing side
        def forward_ROUTER_DEALER(local_port, remote_port, port_type):
            return ROUTER_DEALER_Forwarder('127.0.0.1', local_port, '0.0.0.0',
                                           remote_port, port_type, printout,
                                           True, True)

        def forward_PUB_SUB(local_port, remote_port, port_type):
            return PUB_SUB_Forwarder('127.0.0.1', local_port, '0.0.0.0',
                                     remote_port, port_type, printout, True,
                                     True)

        def forward_REP_REQ(local_port, remote_port, port_type):
            return REQ_REP_Forwarder('127.0.0.1', local_port, '0.0.0.0',
                                     remote_port, port_type, printout, True,
                                     True)

        socket_forwarders = {
            'hb': forward_REP_REQ,
            'shell': forward_ROUTER_DEALER,
            'iopub': forward_PUB_SUB,
            'stdin': forward_ROUTER_DEALER,
            'control': forward_ROUTER_DEALER
        }

        for port_pair in port_pairs:
            local_port = port_pair[0]
            remote_port = port_pair[1]
            port_type = port_pair[2]
            logging.info("Relay port %s to %s on type %s" %
                         (local_port, remote_port, port_type))

            socket_forwarder = socket_forwarders[port_type](local_port,
                                                            remote_port,
                                                            port_type)

            # retrieve what has been bound
            port_pair[1] = socket_forwarder.remote_port

        # swap the ports that the jupyter server knows, and that this forwarder now handles, for
        # the ports it opened for listening for the remote kernel
        for port_pair in port_pairs:
            local_connection_file['%s_port' % port_pair[2]] = port_pair[1]

        # and open a new socket for the comm to the remote kernel overseer (i.e. runner.py in the container)
        context = zmq.Context()
        self.callback_socket = context.socket(zmq.REP)
        callback_port_selected = self.callback_socket.bind_to_random_port(
            'tcp://*', min_port=10000, max_port=30000, max_tries=100)
        local_connection_file['relayPort'] = callback_port_selected
        self.signaling_socket = context.socket(zmq.PUB)
        signal_port_selected = self.signaling_socket.bind_to_random_port(
            'tcp://*', min_port=10000, max_port=30000, max_tries=100)
        local_connection_file['signalPort'] = signal_port_selected

        remote_kernel = backend_json_call(
            "jupyter/start-remote-kernel",
            data={
                "contextProjectKey": remoterun.get_env_var("DKU_CURRENT_PROJECT_KEY"),
                "connectionFile": json.dumps(local_connection_file),
                "remoteKernelType": self.remote_kernel_type,
                "projectKey": self.project_key,
                "bundleId": self.bundle_id,
                "envLang": self.env_lang,
                "envName": self.env_name,
                "containerConf": self.container_conf
            })

        logging.info("Started, got : %s" % json.dumps(remote_kernel))
        self.batch_id = remote_kernel['id']

        # start the thread that polls the backend-side thread, to kill this process whenever that thread dies
        # this has to be started before we block on the remote kernel ACK
        self.start_wait_for_remote_kernel_death()

        # block until the remote end has started its kernel
        message = self.callback_socket.recv()
        logging.info("Got %s" % message)
        self.callback_socket.send('ok'.encode('utf8'))

        # start the heartbeating
        hb_thread = threading.Thread(name="forwarder-watcher",
                                     target=self.hb_forwarder)
        hb_thread.daemon = True
        hb_thread.start()

        def caught_sigint(signum, frame):
            print('Signal handler called with signal %s' % signum)
            self.signaling_socket.send('sigint'.encode('utf8'))

        signal.signal(signal.SIGINT, caught_sigint)
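For context, a minimal sketch of the counterpart handshake the remote kernel overseer (runner.py, per the comment above) would perform against the two sockets opened here; the host name, payloads, and the connection_file variable (the parsed connectionFile JSON on the remote side) are illustrative assumptions:

# Hypothetical remote-side sketch: answer the REP callback socket with a REQ
# socket, and subscribe a SUB socket to the PUB signaling socket.
import zmq

context = zmq.Context()

req = context.socket(zmq.REQ)
req.connect("tcp://forwarder-host:%s" % connection_file["relayPort"])  # illustrative host
req.send("kernel started".encode("utf8"))  # unblocks callback_socket.recv() above
assert req.recv() == b"ok"

sub = context.socket(zmq.SUB)
sub.connect("tcp://forwarder-host:%s" % connection_file["signalPort"])
sub.setsockopt(zmq.SUBSCRIBE, b"")  # listen for the 'sigint' notifications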