Ejemplo n.º 1
0
def configure_and_create_session(distribution_strategy):
    """Create a session suited to `distribution_strategy` and register it with Keras.

    The starting point is always `K.get_default_session_config()`. Then:
      * TPU strategies (per `is_tpu_strategy`): the strategy configures the
        default config, and the session targets the master address reported by
        the strategy's TPU cluster resolver.
      * Non-TPU strategies inside a distribute-coordinator worker context: the
        default config is merged into the worker context's session config, and
        the session targets the worker context's `master_target`.
      * Any other case: the strategy configures the default config, and the
        session is created without an explicit target.

    The resulting session is installed via `K.set_session`.

    Args:
      distribution_strategy: The distribution strategy in use.
    """
    # TODO(priyag): Throw error if a session already exists.
    default_config = K.get_default_session_config()

    if not is_tpu_strategy(distribution_strategy):
        context = dc_context.get_current_worker_context()
        if not context:
            # Plain local run: let the strategy adjust the default config.
            distribution_strategy.configure(default_config)
            new_session = session_module.Session(config=default_config)
        else:
            merged_config = context.session_config
            # Fold the default config into the coordinator's config. Per the
            # original note this is acceptable for now because the two do not
            # set conflicting fields.
            merged_config.MergeFrom(default_config)
            new_session = session_module.Session(
                config=merged_config, target=context.master_target)
    else:
        # TODO(priyag, yuefengz): Remove this workaround when Distribute
        # Coordinator is integrated with keras and we can create a session from
        # there.
        distribution_strategy.configure(default_config)
        tpu_master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
        new_session = session_module.Session(
            config=default_config, target=tpu_master)

    K.set_session(new_session)
def configure_and_create_session(distribution_strategy):
  """Configure session config and create a session with it.

  Starts from `K.get_default_session_config()` and then:
    * TPU strategies (per `is_tpu_strategy`): the strategy configures the
      config, and the session targets the master of the strategy's TPU
      cluster resolver.
    * Non-TPU strategies inside a distribute-coordinator worker context: the
      default config is merged into the worker context's session config, and
      the session targets `worker_context.master_target`.
    * Otherwise (local run): the strategy configures the config, and the
      session is created with no explicit target.

  The created session is installed as the Keras session via `K.set_session`.

  Args:
    distribution_strategy: The distribution strategy in use; must expose
      `configure(session_config)`.
  """
  # TODO(priyag): Throw error if a session already exists.
  session_config = K.get_default_session_config()

  if is_tpu_strategy(distribution_strategy):
    # TODO(priyag, yuefengz): Remove this workaround when Distribute
    # Coordinator is integrated with keras and we can create a session from
    # there.
    distribution_strategy.configure(session_config)
    master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
    session = session_module.Session(config=session_config, target=master)
  else:
    worker_context = dc_context.get_current_worker_context()
    if worker_context:
      # NOTE(review): `configure` is not called on this path; presumably the
      # distribute coordinator already configured the worker's session config
      # -- confirm.
      dc_session_config = worker_context.session_config
      # Merge the default session config to the one from distribute coordinator,
      # which is fine for now since they don't have conflicting configurations.
      dc_session_config.MergeFrom(session_config)
      session = session_module.Session(
          config=dc_session_config, target=worker_context.master_target)
    else:
      # Let the strategy apply its settings to the local session config, just
      # as the TPU branch does; without this call any strategy-specific session
      # options are silently dropped for local runs.
      distribution_strategy.configure(session_config)
      session = session_module.Session(config=session_config)

  K.set_session(session)
def configure_and_create_session(distribution_strategy):
    """Create a Keras session configured by `distribution_strategy`.

    The default Keras session config is first passed to the strategy's
    `configure` method. If the strategy's class is named exactly
    'TPUStrategy', the session targets the master address of the strategy's
    TPU cluster resolver. Otherwise the session is created with no explicit
    target. The session is then installed via `K.set_session`.

    Args:
      distribution_strategy: Strategy object exposing `configure(config)`; a
        `TPUStrategy` must additionally carry `_tpu_cluster_resolver`.
    """
    # TODO(priyag): Throw error if a session already exists.
    config = K.get_default_session_config()
    distribution_strategy.configure(config)

    # NOTE(review): this is an exact class-name match, so subclasses or
    # differently-named TPU strategy classes take the non-TPU path.
    runs_on_tpu = distribution_strategy.__class__.__name__ == 'TPUStrategy'
    if not runs_on_tpu:
        new_session = session_module.Session(config=config)
    else:
        # TODO(priyag): Remove this workaround when Distributed Coordinator is
        # integrated with keras and we can create a session from there.
        tpu_master = distribution_strategy._tpu_cluster_resolver.master()  # pylint: disable=protected-access
        new_session = session_module.Session(config=config, target=tpu_master)

    K.set_session(new_session)
def configure_and_create_session(distribution_strategy):
  """Build a session from the strategy-configured default config and install it.

  `K.get_default_session_config()` is handed to
  `distribution_strategy.configure`. A session is then created with that
  config. When the strategy's class is named exactly 'TPUStrategy', the
  session additionally targets its TPU cluster resolver's master. The result
  is installed with `K.set_session`.

  Args:
    distribution_strategy: Strategy object exposing `configure(config)`; a
      `TPUStrategy` must additionally carry `_tpu_cluster_resolver`.
  """
  # TODO(priyag): Throw error if a session already exists.
  cfg = K.get_default_session_config()
  distribution_strategy.configure(cfg)

  # Collect the Session arguments; only TPU runs add an explicit target.
  session_kwargs = {'config': cfg}
  if distribution_strategy.__class__.__name__ == 'TPUStrategy':
    # TODO(priyag): Remove this workaround when Distributed Coordinator is
    # integrated with keras and we can create a session from there.
    session_kwargs['target'] = distribution_strategy._tpu_cluster_resolver.master()  # pylint: disable=protected-access

  K.set_session(session_module.Session(**session_kwargs))