コード例 #1
0
ファイル: estimator.py プロジェクト: zubrabubra/tensorflow
def _get_replica_device_setter(config):
    """Creates a replica device setter if required as a default device_fn.

  `Estimator` uses ReplicaDeviceSetter as a default device placer. It sets the
  distributed related arguments such as number of ps_replicas based on given
  config.

  Args:
    config: A `RunConfig` instance.

  Returns:
    A replica device setter, or None.
  """
    ps_ops = [
        'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
        'MutableHashTableOfTensors', 'MutableDenseHashTable'
    ]

    if config.task_type:
        worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
    else:
        worker_device = '/job:worker'

    if config.num_ps_replicas > 0:
        return training.replica_device_setter(ps_tasks=config.num_ps_replicas,
                                              worker_device=worker_device,
                                              merge_devices=True,
                                              ps_ops=ps_ops,
                                              cluster=config.cluster_spec)
    else:
        return None
コード例 #2
0
ファイル: estimator.py プロジェクト: ilya-edrenkin/tensorflow
def _get_replica_device_setter(config):
  """Creates a replica device setter if required as a default device_fn.

  `Estimator` uses ReplicaDeviceSetter as a default device placer. It sets the
  distributed related arguments such as number of ps_replicas based on given
  config.

  Args:
    config: A `RunConfig` instance.

  Returns:
    A replica device setter, or None.
  """
  ps_ops = [
      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableV2', 'MutableHashTableOfTensors',
      'MutableHashTableOfTensorsV2', 'MutableDenseHashTable',
      'MutableDenseHashTableV2'
  ]

  if config.task_type:
    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return training.replica_device_setter(
        ps_tasks=config.num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        ps_ops=ps_ops,
        cluster=config.cluster_spec)
  else:
    return None
コード例 #3
0
def _get_replica_device_setter(config):
  """Creates a replica device setter if required as a default device_fn.

  `Estimator` uses ReplicaDeviceSetter as a default device placer. It sets the
  distributed related arguments such as number of ps_replicas based on given
  config.

  Args:
    config: A `RunConfig` instance.

  Returns:
    A replica device setter, or None.
  """
  if config.task_type:
    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return training.replica_device_setter(
        ps_tasks=config.num_ps_replicas,
        worker_device=worker_device,
        merge_devices=True,
        ps_ops=list(device_setter.STANDARD_PS_OPS),
        cluster=config.cluster_spec)
  else:
    return None