Example #1
0
 def _deserialize_dict(self, dict_values):
     deserialized_dict = dict()
     for key, val in dict_values.items():
         if val is None:
             deserialized_dict[key] = None
         elif key == EstimatorParams.model.name:
             deserialize = deserialize_fn()
             deserialized_dict[key] = deserialize(val)
         else:
             deserialized_dict[key] = codec.loads_base64(val)
     return deserialized_dict
Example #2
0
        def _param_deserializer_fn(name, param_val, keras_utils,
                                   custom_objects):
            if param_val is None:
                return param_val

            if name == EstimatorParams.model.name:

                def load_model_fn(x):
                    with keras_utils.keras().utils.custom_object_scope(
                            custom_objects):
                        return keras_utils.keras().models.load_model(
                            x, compile=True)

                return keras_utils.deserialize_model(
                    param_val, load_model_fn=load_model_fn)
            elif name == KerasEstimator.optimizer.name:
                opt_base64_encoded = codec.loads_base64(param_val)
                return keras_utils.deserialize_optimizer(opt_base64_encoded)
            else:
                return codec.loads_base64(param_val)
Example #3
0
def read_data_from_kvstore(addr, port, scope, key):
    """Read and decode the value stored under ``scope``/``key`` on a KVStore server.

    Requests ``http://{addr}:{port}/{scope}/{key}`` and decodes the response
    body with ``codec.loads_base64``.

    :param addr: host name or IP address of the KVStore server
    :param port: port of the KVStore server (converted with ``str``)
    :param scope: scope segment of the URL path
    :param key: key segment of the URL path
    :return: the object produced by ``codec.loads_base64`` from the body
    :raises RuntimeError: if the request fails with ``HTTPError`` or
        ``URLError``; the original exception is the second argument
    """
    url = "http://{addr}:{port}/{scope}/{key}".format(addr=addr,
                                                      port=str(port),
                                                      scope=scope,
                                                      key=key)
    try:
        resp = urlopen(Request(url))
        try:
            payload = resp.read()
        finally:
            # Always release the underlying connection, even if reading
            # the body fails part-way through.
            resp.close()
        # TODO: remove base64 encoding because base64 is not efficient
        return codec.loads_base64(payload)
    except (HTTPError, URLError) as e:
        raise RuntimeError("Read data from KVStore server failed.", e)
Example #4
0
    def _deserialize(model_bytes_base64):
        """Decode a base64-encoded model payload.

        The payload is decoded with ``codec.loads_base64``. If the result is
        already a ``torch.nn.Module`` it is returned directly; otherwise it
        is treated as a readable, seekable stream whose content is loaded
        with ``torch.jit.load``.
        """
        # NOTE(review): ``torch`` is only bound when the availability check
        # passes, yet it is used unconditionally below; confirm callers
        # never reach this without torch installed.
        if is_module_available('torch'):
            import torch
            # Alias registered before decoding; presumably needed so the
            # payload can resolve "torch._C._nn" while loading (TODO confirm).
            sys.modules["torch._C._nn"] = torch.nn.functional

        model = codec.loads_base64(model_bytes_base64)
        if isinstance(model, torch.nn.Module):
            return model

        # Rewind the stream and copy it into a fresh in-memory buffer
        # before handing it to torch.jit.load.
        model.seek(0)
        buffer = io.BytesIO(model.read())
        return torch.jit.load(buffer)
Example #5
0
 def deserialize_keras_model(model_bytes, load_model_fn):
     """Load a Keras model from a base64-encoded payload.

     ``model_bytes`` is decoded with ``codec.loads_base64``, wrapped in an
     in-memory buffer and opened read-only as an HDF5 file; the open file
     is passed to ``load_model_fn`` and its result is returned.
     """
     raw_bytes = codec.loads_base64(model_bytes)
     with h5py.File(io.BytesIO(raw_bytes), 'r') as hdf5_file:
         model = load_model_fn(hdf5_file)
     return model
Example #6
0
                                                      next_task_addresses,
                                                      settings.key,
                                                      settings.verbose,
                                                      match_intf=True,
                                                      attempts=10)
        driver.register_task_to_task_addresses(next_task_index,
                                               next_task.addresses())
        # Notify the next task that the address checks are completed.
        next_task.task_to_task_address_check_completed()
        # Wait to get a notification from previous task that its address checks
        # are completed as well.
        task.wait_for_task_to_task_address_check_finish_signal(
            settings.start_timeout)

    finally:
        task.shutdown()


if __name__ == '__main__':
    # Expect exactly four base64-encoded arguments after the script name.
    if len(sys.argv) != 5:
        usage = 'Usage: {} <index> <num_hosts> <driver_addresses> <settings>'
        print(usage.format(sys.argv[0]))
        sys.exit(1)

    # Decode the arguments in command-line order.
    index, num_hosts, driver_addresses, settings = [
        codec.loads_base64(encoded) for encoded in sys.argv[1:]
    ]

    _task_fn(index, num_hosts, driver_addresses, settings)
Example #7
0
if __name__ == '__main__':
    """
    Method run by MPI to connect to a host hash and execute the given command.

    The command is usually `orted` to setup the MPI cluster. That `orted` process
    is then used to spin-up the actual remote process, the Horovod user's Python method.
    The `orted` process will run on the lowest task index and all other tasks with the
    same host hash are expected to no-op (see `horovod.spark._task_fn`)
    and wait for the first task to terminate.

    :param driver_addresses: all IP addresses of the driver, base64 encoded
    :param settings: all settings, base64 encoded
    :param host_hash: the host hash to connect to
    :param command: the command and arguments to execute remotely
    """
    # At least one command token is required after the three fixed arguments.
    if len(sys.argv) < 5:
        usage = ('Usage: %s <service addresses> <settings> <host hash> '
                 '<command...>')
        print(usage % sys.argv[0])
        sys.exit(1)

    encoded_addresses, encoded_settings, host_hash = sys.argv[1:4]

    addresses = codec.loads_base64(encoded_addresses)
    # The secret key is read from the environment, not the command line.
    key = codec.loads_base64(os.environ.get(secret.HOROVOD_SECRET_KEY))
    settings = codec.loads_base64(encoded_settings)

    # Everything after the host hash forms the command line to execute.
    command = " ".join(sys.argv[4:])

    # orted does not need any env vars, the target training code gets env from mpirun
    env = dict()

    # Since tasks with the same host hash have shared memory,
    # we will run only one orted process on the first task.
    rsh(addresses, key, host_hash, command, env, 0, settings.verbose)
Example #8
0
    # prepend HOROVOD_SPARK_PYTHONPATH to PYTHONPATH
    if 'HOROVOD_SPARK_PYTHONPATH' in os.environ:
        ppath = os.environ['HOROVOD_SPARK_PYTHONPATH']

        # add injected HOROVOD_SPARK_PYTHONPATH to sys.path
        for p in reversed(ppath.split(os.pathsep)):
            sys.path.insert(1, p)  # don't put it in front which is usually .

        if 'PYTHONPATH' in os.environ:
            ppath = os.pathsep.join([ppath, os.environ['PYTHONPATH']])
        os.environ['PYTHONPATH'] = ppath

    # change current working dir to where the Spark worker runs
    # because orted runs this script where mpirun was executed
    # this env var is injected by the Spark task service
    work_dir = os.environ.get('HOROVOD_SPARK_WORK_DIR')
    if work_dir:
        if settings.verbose >= 2:
            print("Changing cwd from {} to {}".format(os.getcwd(), work_dir))
        os.chdir(work_dir)

    task_exec(driver_addresses, settings, 'OMPI_COMM_WORLD_RANK',
              'OMPI_COMM_WORLD_LOCAL_RANK')


if __name__ == '__main__':
    # Exactly two base64-encoded arguments are expected after the script name.
    if len(sys.argv) != 3:
        print('Usage: %s <driver addresses> <settings>' % sys.argv[0])
        sys.exit(1)

    encoded_driver_addresses, encoded_settings = sys.argv[1:]
    main(codec.loads_base64(encoded_driver_addresses),
         codec.loads_base64(encoded_settings))