Example #1
  def testRemote(self):
    gpus = config.list_logical_devices('GPU')
    self.assertNotEqual(len(gpus), 0)

    context.ensure_initialized()

    gpus = config.list_logical_devices('GPU')
    self.assertNotEqual(len(gpus), 0)
    for gpu in gpus:
      self.assertIsNotNone(gpu.name)

    context.ensure_initialized()

    job_name = 'test'
    cluster_def = cluster_pb2.ClusterDef()
    job_def = cluster_def.job.add()
    job_def.name = job_name
    job_def.tasks[0] = 'localhost:0'

    server_def = tensorflow_server_pb2.ServerDef(
        cluster=cluster_def, job_name=job_name, task_index=0, protocol='grpc')

    context.set_server_def(server_def)

    gpus = config.list_logical_devices('GPU')
    for gpu in gpus:
      self.assertIsNotNone(gpu.name)
Example #2
  def testBadConstructorArgs(self):
    context.ensure_initialized()
    ctx = context.context()
    handle = ctx._handle
    device = ctx.device_name
    # Missing context.
    with self.assertRaisesRegexp(
        TypeError, r".*argument 'context' \(pos 2\).*"):
      ops.EagerTensor(1, device=device)
    # Missing device.
    with self.assertRaisesRegexp(
        TypeError, r".*argument 'device' \(pos 3\).*"):
      ops.EagerTensor(1, context=handle)
    # Bad dtype type.
    with self.assertRaisesRegexp(TypeError,
                                 "Expecting a DataType value for dtype. Got"):
      ops.EagerTensor(1, context=handle, device=device, dtype="1")

    # Following errors happen when trying to copy to GPU.
    if not test_util.is_gpu_available():
      self.skipTest("No GPUs found")

    with ops.device("/device:GPU:0"):
      device = ctx.device_name
      # Bad context.
      with self.assertRaisesRegexp(
          TypeError, "Expecting a PyCapsule encoded context handle. Got"):
        ops.EagerTensor(1.0, context=1, device=device)
      # Bad device.
      with self.assertRaisesRegexp(
          TypeError, "Error parsing device argument to CopyToDevice"):
        ops.EagerTensor(1.0, context=handle, device=1)
Example #3
  def testCpuMultiple(self):
    cpus = config.list_physical_devices('CPU')
    self.assertEqual(len(cpus), 1)

    config.set_virtual_device_configuration(cpus[0], [
        context.VirtualDeviceConfiguration(),
        context.VirtualDeviceConfiguration()
    ])

    context.ensure_initialized()

    cpus = config.list_logical_devices('CPU')
    self.assertEqual(len(cpus), 2)

    with ops.device('/device:CPU:0'):
      a = constant_op.constant(1.0)
      self.evaluate(a)
    with ops.device('/device:CPU:1'):
      b = constant_op.constant(1.0)
      self.evaluate(b)
    with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
      with ops.device('/device:CPU:2'):
        c = constant_op.constant(1.0)
        self.evaluate(c)

    # Ensure we can place ops on each of the device names
    for cpu in cpus:
      with ops.device(cpu.name):
        d = constant_op.constant(1.0)
        self.evaluate(d)
Example #4
def op_attr_type(op_type, attr_name):
  try:
    return _op_attr_type_cache[(op_type, attr_name)]
  except KeyError:
    context.ensure_initialized()
    h = context.context()._handle  # pylint: disable=protected-access
    attr_type = pywrap_tensorflow.TFE_OpNameGetAttrType(h, op_type, attr_name)
  _op_attr_type_cache[(op_type, attr_name)] = attr_type
  return attr_type
Example #5
  def __init__(self, dist, coord, replica_id, device_map, variable_creator_fn,
               fn, args, kwargs):
    super(_MirroredReplicaThread, self).__init__()
    self.coord = coord
    self.distribution = dist
    self.device_map = device_map
    self.replica_id = replica_id
    self.variable_creator_fn = variable_creator_fn
    # State needed to run and return the results of `fn`.
    self.main_fn = fn
    self.main_args = args
    self.main_kwargs = kwargs
    self.main_result = None
    self.done = False
    # State needed to run the next merge_call() (if any) requested via
    # ReplicaContext.
    self.merge_fn = None
    self.merge_args = None
    self.merge_kwargs = None
    self.merge_result = None
    self.captured_name_scope = None
    self.captured_var_scope = None
    # We use a thread.Event for the main thread to signal when this
    # thread should start running (`should_run`), and another for
    # this thread to transfer control back to the main thread
    # (`has_paused`, either when it gets to a
    # `get_replica_context().merge_call` or when `fn` returns). In
    # either case the event starts cleared and is signaled by calling
    # set(). The receiving thread waits for the signal by calling
    # wait() and then immediately clearing the event using clear().
    self.should_run = threading.Event()
    self.has_paused = threading.Event()
    # These fields have to do with inheriting various contexts from the
    # parent thread:
    context.ensure_initialized()
    ctx = context.context()
    self.in_eager = ctx.executing_eagerly()
    self.record_thread_local_context_fields()
    self.context_device_policy = (
        pywrap_tensorflow.TFE_ContextGetDevicePlacementPolicy(
            ctx._context_handle))
    self.graph = ops.get_default_graph()
    with ops.init_scope():
      self._init_in_eager = context.executing_eagerly()
      self._init_graph = ops.get_default_graph()

    self._variable_creator_stack = self.graph._variable_creator_stack[:]
    self._var_scope = variable_scope.get_variable_scope()
    # Adding a "/" at the end lets us re-enter this scope later.
    self._name_scope = self.graph.get_name_scope()
    if self._name_scope:
      self._name_scope += "/"
    if self.replica_id > 0:
      if not self._name_scope:
        self._name_scope = ""
      self._name_scope += "replica_%d/" % self.replica_id
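The comment in Example #5 describes a two-Event handshake between the main thread and the replica thread. The snippet below is a minimal, TensorFlow-free sketch of that should_run/has_paused pattern; the WorkerThread class and its fields are hypothetical names chosen for illustration, not part of the library.

import threading

class WorkerThread(threading.Thread):
  """Hypothetical worker illustrating the should_run/has_paused handshake."""

  def __init__(self, fn):
    super().__init__()
    self.fn = fn
    self.result = None
    self.should_run = threading.Event()  # main thread -> worker: start running
    self.has_paused = threading.Event()  # worker -> main thread: control returned

  def run(self):
    self.should_run.wait()   # block until the main thread signals us
    self.should_run.clear()  # clear immediately so the event can be reused
    self.result = self.fn()
    self.has_paused.set()    # hand control back to the main thread

# The main thread signals the worker, then waits for it to pause.
worker = WorkerThread(lambda: 1 + 1)
worker.start()
worker.should_run.set()
worker.has_paused.wait()
worker.has_paused.clear()
print(worker.result)  # 2
worker.join()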
Example #6
  def testInterOpParallelismThreads(self):
    config.set_inter_op_parallelism_threads(10)
    self.assertEqual(
        config.get_inter_op_parallelism_threads(),
        context.context().inter_op_parallelism_threads)

    context.ensure_initialized()

    with self.assertRaises(RuntimeError):
      config.set_inter_op_parallelism_threads(1)
Example #7
  def __init__(self, name=None, shared_name=None,
               critical_section_def=None, import_scope=None):
    """Creates a critical section."""
    context.ensure_initialized()
    if critical_section_def and name is not None:
      raise ValueError("critical_section_def and shared_name are "
                       "mutually exclusive.")
    if critical_section_def:
      self._init_from_proto(critical_section_def, import_scope=import_scope)
    else:
      self._init_from_args(name, shared_name)
Example #8
def _create_tensor(value, device=None, dtype=None):
  context.ensure_initialized()
  ctx = context.context()
  if device is None:
    device = ctx.device_name
  if dtype is not None:
    dtype = dtype.as_datatype_enum
  try:
    return ops.EagerTensor(
        value, context=ctx._handle, device=device, dtype=dtype)
  except core._NotOkStatusException as e:  # pylint: disable=protected-access
    raise core._status_to_exception(e.code, e.message)
Example #9
  def __init__(self, func, Tout, is_grad_func):
    """Constructs an EagerFunc.

    Args:
      func: The function to wrap.
      Tout: A list of datatypes for the output; an empty list if the output is
        None.
      is_grad_func: Whether this EagerFunc is the gradient of another
        EagerPyFunc.
    """
    self._func = func
    self._out_dtypes = Tout
    self._is_grad_func = is_grad_func

    context.ensure_initialized()
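EagerFunc pairs a Python callable with the output dtypes in `Tout` so the runtime can call back into Python. The public entry point built on this machinery is `tf.py_function`; below is a small sketch of how `func` and `Tout` map onto that API (the function name `log1p_eager` is just an illustrative choice), not the internal wrapper itself.

import tensorflow as tf

def log1p_eager(x):
  # Executed eagerly as ordinary Python; x arrives as an eager tensor.
  return tf.math.log(1.0 + x)

x = tf.constant([0.0, 1.0, 2.0])
# Tout plays the same role as in EagerFunc: it declares the output dtype(s).
y = tf.py_function(func=log1p_eager, inp=[x], Tout=tf.float32)
print(y.numpy())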
Example #10
  def testDevicePolicy(self):
    self.assertEqual(context.DEVICE_PLACEMENT_SILENT,
                     context.context().device_policy)

    # If no op has been executed we should be able to set the device policy as
    # well as any init-time configs.
    config.set_intra_op_parallelism_threads(1)
    config.set_device_policy('silent')
    config.set_intra_op_parallelism_threads(2)

    context.ensure_initialized()

    def copy_tensor(dtype=dtypes.int32):
      cpu_tensor = constant_op.constant(1, dtype=dtype)
      gpu_tensor = cpu_tensor.gpu()
      self.assertAllEqual(cpu_tensor + gpu_tensor, 2.0)

    config.set_device_policy('silent')
    self.assertEqual(config.get_device_policy(), 'silent')
    self.assertEqual(context.DEVICE_PLACEMENT_SILENT,
                     context.context().device_policy)
    copy_tensor()

    config.set_device_policy('silent_for_int32')
    self.assertEqual(config.get_device_policy(), 'silent_for_int32')
    self.assertEqual(context.DEVICE_PLACEMENT_SILENT_FOR_INT32,
                     context.context().device_policy)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 'Tensors on conflicting devices'):
      copy_tensor(dtypes.float32)
    copy_tensor()

    config.set_device_policy('warn')
    self.assertEqual(config.get_device_policy(), 'warn')
    self.assertEqual(context.DEVICE_PLACEMENT_WARN,
                     context.context().device_policy)
    copy_tensor()

    config.set_device_policy('explicit')
    self.assertEqual(config.get_device_policy(), 'explicit')
    self.assertEqual(context.DEVICE_PLACEMENT_EXPLICIT,
                     context.context().device_policy)
    with self.assertRaisesRegexp(errors.InvalidArgumentError,
                                 'Tensors on conflicting devices'):
      copy_tensor()

    config.set_device_policy(None)
    self.assertEqual(config.get_device_policy(), 'silent')
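Example #10 drives the device placement policies through internal test helpers. A rough equivalent with the public API, sketched under the assumption that at least one GPU is visible, looks like this:

import tensorflow as tf

# 'silent' copies tensors between devices implicitly when an op needs them;
# 'explicit' raises an error instead of copying.
tf.config.experimental.set_device_policy('silent')
print(tf.config.experimental.get_device_policy())  # 'silent'

if tf.config.list_logical_devices('GPU'):
  with tf.device('/CPU:0'):
    cpu_tensor = tf.constant(1.0)
  with tf.device('/GPU:0'):
    gpu_tensor = tf.constant(1.0)
  # Under 'silent' this cross-device add succeeds via an implicit copy;
  # under 'explicit' it would raise an InvalidArgumentError instead.
  print((cpu_tensor + gpu_tensor).numpy())  # 2.0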
Example #11
  def testGpuMultiple(self):
    gpus = config.list_physical_devices('GPU')
    if len(gpus) < 2:
      self.skipTest('Need at least 2 GPUs')

    context.ensure_initialized()

    for i in range(0, len(gpus)):
      with ops.device('/device:GPU:' + str(i)):
        a = constant_op.constant(1.0)
        self.evaluate(a)

    with self.assertRaisesRegex(RuntimeError, 'unknown device'):
      with ops.device('/device:GPU:' + str(len(gpus))):
        a = constant_op.constant(1.0)
        self.evaluate(a)
Example #12
  def testGpuMultiple(self):
    gpus = config.list_physical_devices('GPU')
    if len(gpus) < 2:
      self.skipTest('Need at least 2 GPUs')

    context.ensure_initialized()

    for i in range(0, len(gpus)):
      with ops.device('/device:GPU:' + str(i)):
        a = constant_op.constant(1.0)
        self.evaluate(a)

    with self.assertRaisesRegexp(RuntimeError, 'unknown device'):
      with ops.device('/device:GPU:' + str(len(gpus))):
        a = constant_op.constant(1.0)
        self.evaluate(a)
Example #13
def start_profiler_server(port):
  """Start a profiler gRPC server that listens on the given port.

  The profiler server will keep the program running even after the training
  finishes. Please shut down the server with CTRL-C. It can be used in both
  eager mode and graph mode. The service is defined in
  tensorflow/core/profiler/profiler_service.proto. Please use
  tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture a traceable
  file, following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace

  Args:
    port: The port the profiler server listens on.
  """
  if context.default_execution_mode == context.EAGER_MODE:
    context.ensure_initialized()
  pywrap_tensorflow.TFE_StartProfilerServer(port)
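The internal start_profiler_server above corresponds to the public tf.profiler.experimental.server.start API in TF 2.x. A minimal sketch; the port number 6009 is an arbitrary choice:

import tensorflow as tf

# Start a profiler gRPC server in this process; a remote client (for example
# TensorBoard's profile plugin) can then connect and capture a trace.
tf.profiler.experimental.server.start(6009)

# The server keeps serving in the background while the program does work.
x = tf.random.normal([1000, 1000])
y = tf.matmul(x, x)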
Example #14
    def testCpuMultiple(self):
        cpus = config.list_physical_devices('CPU')
        self.assertEqual(len(cpus), 1)

        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])

        context.ensure_initialized()

        vcpus = config.list_logical_devices('CPU')
        self.assertEqual(len(vcpus), 2)

        with ops.device('/device:CPU:0'):
            a = constant_op.constant(1.0)
            self.evaluate(a)
        with ops.device('/device:CPU:1'):
            b = constant_op.constant(1.0)
            self.evaluate(b)
        with ops.device('/device:CPU:2'):
            c = constant_op.constant(1.0)
            self.evaluate(c)
        if test_util.is_gpu_available():
            self.assertIn('GPU:0', c.device)
        else:
            self.assertIn('CPU:0', c.device)

        # Ensure we can place ops on each of the device names
        for vcpu in vcpus:
            with ops.device(vcpu.name):
                d = constant_op.constant(1.0)
                self.evaluate(d)

        # Modifying the CPU configuration is not supported
        with self.assertRaisesRegex(RuntimeError, 'cannot be modified'):
            config.set_logical_device_configuration(cpus[0], [
                context.LogicalDeviceConfiguration(),
                context.LogicalDeviceConfiguration(),
                context.LogicalDeviceConfiguration()
            ])

        # Setting the same CPU configuration is fine
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
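Example #14 uses internal helpers, but the same virtual-CPU setup can be written with the public tf.config API. A sketch that assumes it runs before the TensorFlow runtime has been initialized, since the configuration cannot be changed afterwards (the 'cannot be modified' error above):

import tensorflow as tf

# Must run before the first op executes, i.e. before runtime initialization.
cpus = tf.config.list_physical_devices('CPU')
tf.config.set_logical_device_configuration(cpus[0], [
    tf.config.LogicalDeviceConfiguration(),
    tf.config.LogicalDeviceConfiguration(),
])

logical_cpus = tf.config.list_logical_devices('CPU')
print([d.name for d in logical_cpus])  # two logical CPU devices

with tf.device(logical_cpus[1].name):
  print(tf.constant(1.0) + tf.constant(2.0))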
Example #15
    def __init__(self,
                 name=None,
                 shared_name=None,
                 critical_section_def=None,
                 import_scope=None):
        """Creates a critical section."""
        context.ensure_initialized()
        if critical_section_def and name is not None:
            raise ValueError(
                f"Arguments critical_section_def={critical_section_def} "
                f"and shared_name={shared_name} are mutually exclusive. "
                "Please only specify one of them.")
        if critical_section_def:
            raise ValueError(
                "Argument `critical_section_def` is not supported.")
        else:
            self._init_from_args(name, shared_name)
Example #16
    def testGpuPerProcessMemoryGrowth(self):
        self.assertFalse(config.get_gpu_per_process_memory_growth())

        config.set_gpu_per_process_memory_growth(True)
        self.assertTrue(config.get_gpu_per_process_memory_growth())
        self.assertEqual(config.get_gpu_per_process_memory_growth(),
                         context.context().gpu_per_process_memory_growth)

        config.set_gpu_per_process_memory_growth(False)
        self.assertFalse(config.get_gpu_per_process_memory_growth())
        self.assertEqual(config.get_gpu_per_process_memory_growth(),
                         context.context().gpu_per_process_memory_growth)

        context.ensure_initialized()

        with self.assertRaises(RuntimeError):
            config.set_gpu_per_process_memory_growth(True)
Example #17
    def testLogDevicePlacement(self):
        self.assertFalse(context.get_log_device_placement())

        context.set_log_device_placement(True)
        self.assertEqual(context.get_log_device_placement(), True)
        self.assertEqual(context.get_log_device_placement(),
                         context.context().log_device_placement)

        context.set_log_device_placement(False)
        self.assertEqual(context.get_log_device_placement(), False)
        self.assertEqual(context.get_log_device_placement(),
                         context.context().log_device_placement)

        context.ensure_initialized()

        # Changing the device placement should not throw an exception
        context.set_log_device_placement(True)
Example #18
  def test_expand_distributed_variables(self, expand_strategy):
    context._reset_context()
    cpus = context.context().list_physical_devices("CPU")
    if len(cpus) == 1:
      context.context().set_logical_device_configuration(
          cpus[0], [
              context.LogicalDeviceConfiguration(),
              context.LogicalDeviceConfiguration()
          ])
    context.ensure_initialized()

    file_name = os.path.join(self.get_temp_dir(), "saved_model.pb")
    with mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"]).scope():
      root = tracking.AutoTrackable()
      root.v = variables.Variable([1., 1.], name="v")

      @def_function.function(input_signature=[])
      def f():
        root.v.assign([2., 2.])

      root.f = f

      save.export_meta_graph(
          obj=root,
          filename=file_name,
          options=save_options.SaveOptions(
              experimental_variable_policy=expand_strategy))
    graph_def = meta_graph.read_meta_graph_file(file_name).graph_def
    v0 = next((n for n in graph_def.node if n.name == "v"), None)
    v1 = next((n for n in graph_def.node if n.name == "v/replica_1"), None)
    self.assertIsNotNone(v0)
    saved_function = next((f for f in graph_def.library.function
                           if "inference_f_" in f.signature.name), None)
    self.assertIsNotNone(saved_function)
    if (expand_strategy ==
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES):
      self.assertIsNotNone(v1)
      # experimental_save_variable_devices should have been automatically set.
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
      self.assertLen(saved_function.signature.input_arg, 2)
    else:
      self.assertIsNone(v1)
      self.assertEmpty(v0.device)
      # TODO(b/159752793): There should be only one input here.
      self.assertLen(saved_function.signature.input_arg, 2)
Example #19
    def testCollectiveReduceMinMax(self):
        gpus = config.list_physical_devices('GPU')
        if len(gpus) != 1:
            self.skipTest('Expected 1 GPU but found {} GPUs'.format(len(gpus)))
        config.set_virtual_device_configuration(gpus[0], [
            context.VirtualDeviceConfiguration(1024),
            context.VirtualDeviceConfiguration(1024)
        ])
        context.ensure_initialized()

        @def_function.function
        def run_all_reduce(group_key, instance_key, merge_op):
            group_size = 2
            t0 = [1., 20., 3., 40., 5.]
            t1 = [10., 2., 30., 4., 50.]
            os.environ['NCCL_DEBUG'] = 'INFO'
            os.environ['NCCL_LAUNCH_MODE'] = 'PARALLEL'
            with ops.device('/GPU:0'):
                in0 = constant_op.constant(t0)
                c0 = collective_ops.all_reduce(in0,
                                               group_size,
                                               group_key,
                                               instance_key,
                                               merge_op,
                                               final_op='Id',
                                               communication_hint='nccl')
            with ops.device('/GPU:1'):
                in1 = constant_op.constant(t1)
                c1 = collective_ops.all_reduce(in1,
                                               group_size,
                                               group_key,
                                               instance_key,
                                               merge_op,
                                               final_op='Id',
                                               communication_hint='nccl')
            return c0, c1

        for combination in [('Max', [10., 20., 30., 40., 50.]),
                            ('Min', [1., 2., 3., 4., 5.])]:
            merge_op = combination[0]
            results = run_all_reduce(group_key=10,
                                     instance_key=20,
                                     merge_op=merge_op)
            expected = combination[1]
            for result in results:
                self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5)
Example #20
def dtensor_initialize_multi_client(
        enable_coordination_service: Optional[bool] = False) -> None:
    """Initializes multi-client DTensor.

    The following environment variables control the behavior of this function.
    If the variables are unset, DTensor will be configured to run in
    single-client mode.

    - DTENSOR_CLIENT_ID: integer, between 0 and num_clients - 1, identifying
        the client id of the current process.
    - DTENSOR_NUM_CLIENTS: integer, the number of clients.
    - DTENSOR_JOB_NAME: string, a hostname-like string for the name of the
        dtensor job. The job name is used by TensorFlow in the job name section
        of the DeviceSpec.
    - DTENSOR_JOBS: string, a comma-separated list. Each item in the list is
        of the format `{hostname}:{port}`, and the items must be sorted in
        alphabetical order. The implication is that the RPC port numbers of the
        clients from the same host must be ordered by the client ID.
        Examples of valid DTENSOR_JOBS values:
        - 4 clients on localhost:
          `localhost:10000,localhost:10001,localhost:10002,localhost:10003`
        - 2 clients on host1, 2 clients on host2:
          `host1:10000,host1:10001,host2:10000,host2:10003`

    Args:
      enable_coordination_service: If true, enable the distributed coordination
        service to make sure that workers know each other's devices, a
        prerequisite for data transfer through cross-worker rendezvous.
    """
    assert context.executing_eagerly()

    # Collective GRPC servers are only necessary in multi-client setup.
    # Single clients can use local mode of collectives.
    if api.num_clients() > 1:
        multi_client_util.initialize_multi_client_cluster(
            job_name=api.job_name(),
            dtensor_jobs=api.jobs(),
            client_id=api.client_id(),
            collective_leader=api.full_job_name(task_id=0),
            enable_coordination_service=enable_coordination_service)

    # Make sure the server change is fully propagated before returning.
    context.ensure_initialized()
    context.async_wait()
    context.context()._clear_caches()  # pylint: disable=protected-access
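The docstring above lists the environment variables that switch DTensor into multi-client mode. The sketch below shows what client 0 of a hypothetical two-client job on localhost might export before calling the initializer; the concrete values are illustrative, not defaults.

import os

# Illustrative values for a 2-client job running on one machine.
os.environ['DTENSOR_CLIENT_ID'] = '0'            # this process is client 0
os.environ['DTENSOR_NUM_CLIENTS'] = '2'
os.environ['DTENSOR_JOB_NAME'] = 'worker'
os.environ['DTENSOR_JOBS'] = 'localhost:10000,localhost:10001'

# With these set, the initializer above brings up the multi-client cluster;
# leaving them unset selects single-client mode instead.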
Example #21
def start():
    """Start profiling.

    Raises:
      ProfilerAlreadyRunningError: If another profiling session is running.
    """
    global _profiler
    with _profiler_lock:
        if _profiler is not None:
            raise ProfilerAlreadyRunningError('Another profiler is running.')
        if context.default_execution_mode == context.EAGER_MODE:
            context.ensure_initialized()
        _profiler = pywrap_tensorflow.TFE_NewProfiler()
        if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler):
            logging.warning(
                'Another profiler session is running which is probably '
                'created by profiler server. Please avoid using profiler '
                'server and profiler APIs at the same time.')
Example #22
  def testParamResolutionAfterTimeoutV2(self):
    context._reset_context()
    timeout = 1.5
    cpus = config.list_physical_devices('CPU')
    self.assertEqual(len(cpus), 1)
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    context.ensure_initialized()

    group_key = 20
    instance_key = 30
    input_data = constant_op.constant([1, 2, 3, 4])

    # This timeout comes from param resolution.
    with self.assertRaisesRegex(
        errors.DeadlineExceededError,
        'Collective has timed out waiting for other workers'):
      with ops.device('CPU:0'):
        collective_ops.all_reduce(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            merge_op='Add',
            final_op='Id',
            timeout=timeout)

    # We launch the second device after the first device times out. This is to
    # simulate the situation when other workers are slow and the timeout is
    # short. Since CPU:0 times out in the param resolution phase, CPU:1 should
    # time out as well, but in the execute phase.
    with self.assertRaisesRegex(errors.DeadlineExceededError,
                                'Collective has timed out during execution'):
      with ops.device('CPU:1'):
        collective_ops.all_reduce(
            input_data,
            group_size=2,
            group_key=group_key,
            instance_key=instance_key,
            merge_op='Add',
            final_op='Id',
            timeout=timeout)
Example #23
    def testLogDevicePlacement(self):
        self.assertFalse(context.get_log_device_placement())

        context.set_log_device_placement(True)
        self.assertEqual(context.get_log_device_placement(), True)
        self.assertEqual(context.get_log_device_placement(),
                         context.context().log_device_placement)

        context.set_log_device_placement(False)
        self.assertEqual(context.get_log_device_placement(), False)
        self.assertEqual(context.get_log_device_placement(),
                         context.context().log_device_placement)

        context.ensure_initialized()

        with self.assertRaises(RuntimeError):
            context.set_log_device_placement(True)
        with self.assertRaises(RuntimeError):
            context.set_log_device_placement(False)
Example #24
    def test_save_variable_devices(self, save_devices, meta_graph_only):
        context._reset_context()
        cpus = context.context().list_physical_devices("CPU")
        if len(cpus) == 1:
            context.context().set_logical_device_configuration(
                cpus[0], [
                    context.LogicalDeviceConfiguration(),
                    context.LogicalDeviceConfiguration()
                ])
        context.ensure_initialized()

        root = tracking.AutoTrackable()
        with ops.device("CPU:0"):
            root.v0 = variables.Variable(1., name="v0")
        with ops.device("CPU:1"):
            root.v1 = variables.Variable(1., name="v1")

        options = save_options.SaveOptions(
            experimental_variable_policy=save_devices)
        file_name = os.path.join(self.get_temp_dir(), "saved_model")
        if meta_graph_only:
            save.export_meta_graph(obj=root,
                                   filename=file_name,
                                   options=options)
        else:
            save.save(obj=root, export_dir=file_name, options=options)

        graph_def = None
        if meta_graph_only:
            graph_def = meta_graph.read_meta_graph_file(file_name).graph_def
        else:
            graph_def = loader_impl.parse_saved_model(
                file_name).meta_graphs[0].graph_def
        v0 = next((n for n in graph_def.node if n.name == "v0"), None)
        v1 = next((n for n in graph_def.node if n.name == "v1"), None)
        self.assertIsNotNone(v0)
        self.assertIsNotNone(v1)
        if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
            self.assertIn("CPU:0", v0.device)
            self.assertIn("CPU:1", v1.device)
        else:
            self.assertEmpty(v0.device)
            self.assertEmpty(v1.device)
Example #25
    def testCollectiveTimeoutV2(self):
        context._reset_context()
        timeout = 4.5
        cpus = config.list_physical_devices('CPU')
        self.assertEqual(len(cpus), 1)
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        context.ensure_initialized()

        @def_function.function
        def run_all_reduce(group_size, reported_group_size=None):
            group_key = 20
            instance_key = 30
            tensor = [1, 2, 3, 4]
            results = []
            if reported_group_size is None:
                reported_group_size = group_size
            for i in range(group_size):
                with ops.device('/CPU:{}'.format(i)):
                    input_data = constant_op.constant(tensor)
                    collective_op = collective_ops.all_reduce(
                        input_data,
                        group_size=reported_group_size,
                        group_key=group_key,
                        instance_key=instance_key,
                        merge_op='Add',
                        final_op='Id',
                        timeout=timeout)
                    results.append(collective_op)
            return results

        run_all_reduce(2, 2)

        start_time = time.time()
        with self.assertRaisesRegex(
                errors.DeadlineExceededError,
                'Collective has timed out during execution'):
            run_all_reduce(1, 2)
        elapsed = time.time() - start_time
        self.assertAllGreaterEqual(elapsed, timeout)
Example #26
def start():
  """Start profiling.

  Raises:
    ProfilerAlreadyRunningError: If another profiling session is running.
  """
  global _profiler
  with _profiler_lock:
    if _profiler is not None:
      raise ProfilerAlreadyRunningError('Another profiler is running.')
    if context.default_execution_mode == context.EAGER_MODE:
      context.ensure_initialized()
    _profiler = _pywrap_profiler.ProfilerSession()
    try:
      _profiler.start()
    except errors.AlreadyExistsError:
      logging.warning('Another profiler session is running which is probably '
                      'created by profiler server. Please avoid using profiler '
                      'server and profiler APIs at the same time.')
      raise ProfilerAlreadyRunningError('Another profiler is running.')
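The internal profiler start() shown above maps onto the public tf.profiler.experimental.start/stop pair in current TF 2.x. A minimal sketch with an arbitrary log directory; starting a second session while one is active raises an error, much like ProfilerAlreadyRunningError here.

import tensorflow as tf

logdir = '/tmp/tf_profile'  # arbitrary location for the trace files
tf.profiler.experimental.start(logdir)

x = tf.random.normal([512, 512])
y = tf.matmul(x, x)  # work captured by the active profiling session

tf.profiler.experimental.stop()  # writes the trace for TensorBoard to read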
Example #27
  def testBadConstructorArgs(self):
    context.ensure_initialized()
    ctx = context.context()
    device = ctx.device_name
    # Missing device.
    with self.assertRaisesRegexp(TypeError, r".*argument 'device' \(pos 2\).*"):
      ops.EagerTensor(1)
    # Bad dtype type.
    with self.assertRaisesRegexp(TypeError,
                                 "Expecting a DataType value for dtype. Got"):
      ops.EagerTensor(1, device=device, dtype="1")

    # Following errors happen when trying to copy to GPU.
    if not test_util.is_gpu_available():
      self.skipTest("No GPUs found")

    with ops.device("/device:GPU:0"):
      # Bad device.
      with self.assertRaisesRegexp(TypeError, "Error parsing device argument"):
        ops.EagerTensor(1.0, device=1)
Example #28
    def testCollectiveTensorsHaveNoDeviceSpecified(self):
        context._reset_context()
        cpus = config.list_physical_devices('CPU')
        self.assertEqual(len(cpus), 1)
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        context.ensure_initialized()

        group_size = 2
        group_key = 1
        instance_key = 1

        @def_function.function
        def fn(all_args):
            results = []
            # The inputs have no devices set. This is expected to be a trace-time
            # check only.
            self.assertEqual(all_args[0].device, '')
            self.assertEqual(all_args[1].device, '')

            with ops.device('/CPU:0'):
                results.append(
                    collective_ops.all_reduce(all_args[0], group_size,
                                              group_key, instance_key, 'Add',
                                              'Div'))
            with ops.device('/CPU:1'):
                results.append(
                    collective_ops.all_reduce(all_args[1], group_size,
                                              group_key, instance_key, 'Add',
                                              'Div'))

            return results

        with ops.device('/CPU:0'):
            in0 = constant_op.constant(1)
        with ops.device('/CPU:1'):
            in1 = constant_op.constant(3)
        result = fn([in0, in1])
        self.assertAllClose(result, [2, 2])
Example #29
  def testLogDevicePlacement(self):
    self.assertFalse(context.get_log_device_placement())

    context.set_log_device_placement(True)
    self.assertEqual(context.get_log_device_placement(), True)
    self.assertEqual(
        context.get_log_device_placement(),
        context.context().log_device_placement)

    context.set_log_device_placement(False)
    self.assertEqual(context.get_log_device_placement(), False)
    self.assertEqual(
        context.get_log_device_placement(),
        context.context().log_device_placement)

    context.ensure_initialized()

    with self.assertRaises(RuntimeError):
      context.set_log_device_placement(True)
    with self.assertRaises(RuntimeError):
      context.set_log_device_placement(False)
Example #30
    def testLogDevicePlacement(self):
        self.assertFalse(context.get_log_device_placement())

        context.set_log_device_placement(True)
        self.assertEqual(context.get_log_device_placement(), True)
        self.assertEqual(context.get_log_device_placement(),
                         context.context().log_device_placement)

        context.set_log_device_placement(False)
        self.assertEqual(context.get_log_device_placement(), False)
        self.assertEqual(context.get_log_device_placement(),
                         context.context().log_device_placement)

        context.ensure_initialized()

        with self.assertRaises(RuntimeError):
            context.set_log_device_placement(True)

        # If setting the device placement is a no-op, do not throw a runtime
        # exception.
        context.set_log_device_placement(False)
Example #31
def start():
  """Start profiling.

  Raises:
    ProfilerAlreadyRunningError: If another profiling session is running.
  """
  global _profiler
  with _profiler_lock:
    if _profiler is not None:
      raise ProfilerAlreadyRunningError('Another profiler is running.')
    context.ensure_initialized()
    profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
    if context.default_execution_mode == context.EAGER_MODE:
      pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
          profiler_context,
          context.context()._handle)  # pylint: disable=protected-access
    _profiler = pywrap_tensorflow.TFE_NewProfiler(profiler_context)
    pywrap_tensorflow.TFE_DeleteProfilerContext(profiler_context)
    if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler):
      logging.warning('Another profiler session is running which is probably '
                      'created by profiler server. Please avoid using profiler '
                      'server and profiler APIs at the same time.')
Example #32
    def __init__(self, persistent=False, watch_accessed_variables=True):
        """Creates a new GradientTape.

        Args:
          persistent: Boolean controlling whether a persistent gradient tape
            is created. False by default, which means at most one call can be
            made to the gradient() method on this object.
          watch_accessed_variables: Boolean controlling whether the tape will
            automatically `watch` any (trainable) variables accessed while the
            tape is active. Defaults to True meaning gradients can be requested
            from any result computed in the tape derived from reading a
            trainable `Variable`. If False users must explicitly `watch` any
            `Variable`s they want to request gradients from.
        """
        self._tape = None
        self._persistent = persistent
        self._watch_accessed_variables = watch_accessed_variables
        self._recording = False
        self._created_eagerly = context.executing_eagerly()
        if self._created_eagerly:
            context.ensure_initialized()
            context.context().start_step()
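The docstring above explains the persistent and watch_accessed_variables arguments. A short sketch of both behaviors using the public tf.GradientTape API:

import tensorflow as tf

x = tf.Variable(3.0)

# persistent=True allows more than one gradient() call on the same tape.
with tf.GradientTape(persistent=True) as tape:
  y = x * x
  z = y * y
print(tape.gradient(y, x).numpy())  # 6.0   (dy/dx = 2x)
print(tape.gradient(z, x).numpy())  # 108.0 (dz/dx = 4x^3)
del tape  # release the persistent tape's resources when done

# watch_accessed_variables=False records nothing unless watched explicitly.
w = tf.Variable(2.0)
with tf.GradientTape(watch_accessed_variables=False) as tape:
  tape.watch(x)  # only x is tracked
  y = x * w
grad_x, grad_w = tape.gradient(y, [x, w])
print(grad_x.numpy())  # 2.0 (= w), because x was explicitly watched
print(grad_w)          # None, because w was never watched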
Example #33
    def testOpDefDefaultType(self):
        im = np.random.randint(  # pylint: disable=too-many-function-args
            low=0,
            high=65535,
            size=100,
            dtype=np.uint16).reshape(10, 10, 1)

        context.ensure_initialized()

        fastpath_dtype = test_ops.dtype_with_default_op(im).numpy()
        slowpath_dtype = test_ops.dtype_with_default_op_eager_fallback(
            im, None, context.context()).numpy()
        # Ensure the fastpath and slowpath eager paths work.
        self.assertEqual(fastpath_dtype, slowpath_dtype)

        with ops.Graph().as_default(), self.cached_session():
            graph_dtype_symbolic = test_ops.dtype_with_default_op(im)

            graph_dtype = self.evaluate(graph_dtype_symbolic)
        # Ensure the eager path matches the graph path.
        self.assertEqual(fastpath_dtype, graph_dtype)

        # Unfortunately, as of now, this doesn't work as expected on
        # def_functions, since we convert the numpy arrays to tensors
        # pre-tracing (which won't get overridden by the default type).
        @def_function.function
        def func(im):
            return test_ops.dtype_with_default_op(im)

        function_dtype = func(im).numpy()
        self.assertNotEqual(fastpath_dtype, function_dtype)

        # Captures are OK, since they don't go through the conversion path.
        @def_function.function
        def func_captured():
            return test_ops.dtype_with_default_op(im)

        function_dtype = func_captured().numpy()
        self.assertEqual(fastpath_dtype, function_dtype)
Example #34
  def __init__(self, persistent=False, watch_accessed_variables=True):
    """Creates a new GradientTape.

    Args:
      persistent: Boolean controlling whether a persistent gradient tape
        is created. False by default, which means at most one call can
        be made to the gradient() method on this object.
      watch_accessed_variables: Boolean controlling whether the tape will
        automatically `watch` any (trainable) variables accessed while the tape
        is active. Defaults to True meaning gradients can be requested from any
        result computed in the tape derived from reading a trainable `Variable`.
        If False users must explicitly `watch` any `Variable`s they want to
        request gradients from.
    """
    self._tape = None
    self._persistent = persistent
    self._watch_accessed_variables = watch_accessed_variables
    self._recording = False
    self._created_eagerly = context.executing_eagerly()
    if self._created_eagerly:
      context.ensure_initialized()
      context.context().start_step()
Example #35
  def __call__(self, *args, **kwds):
    """Calls the graph function and warn too frequent tracings."""
    context.ensure_initialized()
    if RUN_FUNCTIONS_EAGERLY:
      return self._python_function(*args, **kwds)

    tracing_count = self._get_tracing_count()
    if self._experimental_compile:
      # V2 control flow relies on XLAControlFlowContext to generate a
      # XLA-compatible function graph.
      xla_context = control_flow_ops.XLAControlFlowContext()
      try:
        xla_context.Enter()
        result = self._call(*args, **kwds)
      finally:
        xla_context.Exit()
    else:
      result = self._call(*args, **kwds)

    if tracing_count == self._get_tracing_count():
      self._call_counter.called_without_tracing()
      return result

    self._call_counter.called_with_tracing()
    recent_tracing_count = self._call_counter.get_tracing_count()
    if recent_tracing_count >= FREQUENT_TRACING_WARNING_THRESHOLD:
      logging.warning(
          "{} out of the last {} calls to {} triggered tf.function retracing. "
          "Tracing is expensive and the excessive number of tracings is likely "
          "due to passing python objects instead of tensors. Also, tf.function "
          "has experimental_relax_shapes=True option that relaxes argument "
          "shapes that can avoid unnecessary retracing. Please refer to "
          "https://www.tensorflow.org/tutorials/customization/performance#python_or_tensor_args"
          " and https://www.tensorflow.org/api_docs/python/tf/function for more "
          "details.".format(recent_tracing_count, self._call_counter.call_count,
                            self._python_function))

    return result
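The warning in Example #35 concerns retracing triggered by passing Python objects instead of tensors. A small sketch of that behavior; the print call runs only while a new concrete function is being traced, so it shows when retracing happens.

import tensorflow as tf

@tf.function
def square(x):
  print('tracing for', x)  # executes only during tracing, not on every call
  return x * x

# Python scalars become part of the signature, so each new value retraces.
square(2)  # traces
square(3)  # traces again

# Tensors of the same dtype and shape reuse a single trace.
square(tf.constant(2))  # traces once for this input signature
square(tf.constant(3))  # no new trace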
Example #36
  def testLogDevicePlacement(self):
    self.assertFalse(context.get_log_device_placement())

    context.set_log_device_placement(True)
    self.assertEqual(context.get_log_device_placement(), True)
    self.assertEqual(
        context.get_log_device_placement(),
        context.context().log_device_placement)

    context.set_log_device_placement(False)
    self.assertEqual(context.get_log_device_placement(), False)
    self.assertEqual(
        context.get_log_device_placement(),
        context.context().log_device_placement)

    context.ensure_initialized()

    with self.assertRaises(RuntimeError):
      context.set_log_device_placement(True)

    # If setting the device placement is a no-op, do not throw a runtime
    # exception.
    context.set_log_device_placement(False)
Example #37
def start():
    """Start profiling.

    Raises:
      ProfilerAlreadyRunningError: If another profiling session is running.
    """
    global _profiler
    with _profiler_lock:
        if _profiler is not None:
            raise ProfilerAlreadyRunningError('Another profiler is running.')
        context.ensure_initialized()
        profiler_context = pywrap_tensorflow.TFE_NewProfilerContext()
        if context.default_execution_mode == context.EAGER_MODE:
            pywrap_tensorflow.TFE_ProfilerContextSetEagerContext(
                profiler_context,
                context.context()._handle)  # pylint: disable=protected-access
        _profiler = pywrap_tensorflow.TFE_NewProfiler(profiler_context)
        pywrap_tensorflow.TFE_DeleteProfilerContext(profiler_context)
        if not pywrap_tensorflow.TFE_ProfilerIsOk(_profiler):
            logging.warning(
                'Another profiler session is running which is probably '
                'created by profiler server. Please avoid using profiler '
                'server and profiler APIs at the same time.')
Example #38
    def testMultipleGroups(self):
        context._reset_context()
        cpus = config.list_physical_devices('CPU')
        self.assertEqual(len(cpus), 1)
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        context.ensure_initialized()
        num_elements = 4

        @def_function.function
        def run_all_reduce(group_size, group_key):
            instance_key = group_key
            input_value = [group_key for i in range(num_elements)]
            collectives = []
            for device_idx in range(group_size):
                with ops.device('/CPU:{}'.format(device_idx)):
                    input_tensor = constant_op.constant(input_value)
                    collectives.append(
                        collective_ops.all_reduce(input_tensor,
                                                  group_size,
                                                  group_key,
                                                  instance_key,
                                                  merge_op='Add',
                                                  final_op='Id'))
            return collectives

        def run_and_assert(group_size, group_key):
            for reduced_tensor in run_all_reduce(group_size, group_key):
                self.assertAllEqual(
                    [group_key * group_size for i in range(num_elements)],
                    reduced_tensor.numpy())

        run_and_assert(group_size=2, group_key=1)
        run_and_assert(group_size=3, group_key=2)
Example #39
    def testCollectiveGroupSizeMismatch(self):
        cpus = config.list_physical_devices('CPU')
        self.assertEqual(len(cpus), 1)
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        context.ensure_initialized()

        @def_function.function
        def run_all_reduce():
            group_key = 10
            instance_key = 20
            t0 = [1, 2, 3, 4]
            t1 = [5, 6, 7, 8]
            with ops.device('/CPU:0'):
                in0 = constant_op.constant(t0)
                c0 = collective_ops.all_reduce(in0,
                                               group_size=2,
                                               group_key=group_key,
                                               instance_key=instance_key,
                                               merge_op='Add',
                                               final_op='Id')
            with ops.device('/CPU:1'):
                in1 = constant_op.constant(t1)
                c1 = collective_ops.all_reduce(in1,
                                               group_size=3,
                                               group_key=group_key,
                                               instance_key=instance_key,
                                               merge_op='Add',
                                               final_op='Id')
            return c0, c1

        with self.assertRaisesRegex(errors.InternalError,
                                    'but that group has size'):
            run_all_reduce()
Example #40
  def __call__(self, *args, **kwds):
    """Calls the graph function."""
    context.ensure_initialized()
    if RUN_FUNCTIONS_EAGERLY:
      return self._python_function(*args, **kwds)
    if self._created_variables:
      # In this case we have created variables on the first call, so we run the
      # defunned version which is guaranteed to never create variables.
      return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    elif self._stateful_fn is not None:
      # In this case we have not created variables on the first call. So we can
      # run the first trace but we should fail if variables are created.
      results = self._stateful_fn(*args, **kwds)
      if self._created_variables:
        raise ValueError("Creating variables on a non-first call to a function"
                         " decorated with tf.function.")
      return results

    # This is the first call of __call__, so we have to initialize.
    initializer_map = {}
    self._initialize(args, kwds, add_initializers_to=initializer_map)
    if self._created_variables:
      try:
        # Attempt to initialize variables eagerly and without conds by lifting
        # out initialization graphs. This is the only initialization strategy
        # compatible with XLA at the moment.
        self._initialize_uninitialized_variables(initializer_map)
      except lift_to_graph.UnliftableError:
        pass  # Fall through to cond-based initialization.
      else:
        # Lifting succeeded, so variables are initialized and we can run the
        # stateless function.
        return self._stateless_fn(*args, **kwds)
    else:
      canon_args, canon_kwds = \
          self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
              *args, **kwds)
      # If we did not create any variables the trace we have is good enough.
      return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds)  # pylint: disable=protected-access

    def fn_with_cond(*inner_args, **inner_kwds):
      """Conditionally runs initialization if it's needed."""
      condition = True
      for wr in self._created_variables:
        variable = wr()
        if variable is None:
          raise ValueError(
              "A tf.Variable created inside your tf.function has been"
              " garbage-collected. Your code needs to keep Python references"
              " to variables created inside `tf.function`s.\n"
              "\n"
              "A common way to raise this error is to create and return a"
              " variable only referenced inside your function:\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  v = tf.Variable(1.0)\n"
              "  return v\n"
              "\n"
              "v = f()  # Crashes with this error message!\n"
              "\n"
              "The reason this crashes is that @tf.function annotated"
              " function returns a **`tf.Tensor`** with the **value** of the"
              " variable when the function is called rather than the"
              " variable instance itself. As such there is no code holding a"
              " reference to the `v` created inside the function and Python"
              " garbage collects it.\n"
              "\n"
              "The simplest way to fix this issue is to create variables"
              " outside the function and capture them:\n"
              "\n"
              "v = tf.Variable(1.0)\n"
              "\n"
              "@tf.function\n"
              "def f():\n"
              "  return v\n"
              "\n"
              "f()  # <tf.Tensor: ... numpy=1.>\n"
              "v.assign_add(1.)\n"
              "f()  # <tf.Tensor: ... numpy=2.>")
        condition = math_ops.logical_and(
            condition, resource_variable_ops.var_is_initialized_op(
                variable.handle))
      # We want to call stateless_fn if possible because it avoids recomputing
      # potentially expensive initializers.
      return control_flow_ops.cond(
          condition,
          lambda: self._stateless_fn(*inner_args, **inner_kwds),
          functools.partial(self._concrete_stateful_fn._filtered_call,  # pylint: disable=protected-access
                            inner_args, inner_kwds))

    # We've created variables and are unable to lift the initialization graphs,
    # so we fall back to initializing with conds while running the function.
    canon_args, canon_kwds = \
        self._stateful_fn._function_spec.canonicalize_function_inputs(  # pylint: disable=protected-access
            *args, **kwds)
    return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
Example #42
  def setUp(self):
    super(SoftDevicePlacementTest, self).setUp()
    context._reset_context()
    context.ensure_initialized()
    config.set_soft_device_placement(enabled=True)
    context.context().log_device_placement = True
Example #43
    def __init__(self,
                 dist,
                 coord,
                 replica_id,
                 devices,
                 variable_creator_fn,
                 fn,
                 caching_scope,
                 args,
                 kwargs,
                 thread_local_callables=None):
        super(_MirroredReplicaThread, self).__init__()
        self.coord = coord
        self.distribution = dist
        self.devices = devices
        self.replica_id = replica_id
        self.replica_id_in_sync_group = (
            dist.extended._get_replica_id_in_sync_group(replica_id))  # pylint: disable=protected-access

        self.variable_creator_fn = variable_creator_fn
        # State needed to run and return the results of `fn`.
        self.main_fn = fn
        self.main_args = args
        self.main_kwargs = kwargs
        self.main_result = None
        self.done = False
        # State needed to run the next merge_call() (if any) requested via
        # ReplicaContext.
        self.merge_fn = None
        self.merge_args = None
        self.merge_kwargs = None
        self.merge_result = None
        self.captured_name_scope = None
        self.captured_var_scope = None
        try:
            self.caching_scope_entered = caching_scope.new_cache_scope_count
            self.caching_scope_exited = caching_scope.cache_scope_exited_count
        except AttributeError:
            self.caching_scope_entered = None
            self.caching_scope_exited = None

        # We use a thread.Event for the main thread to signal when this
        # thread should start running (`should_run`), and another for
        # this thread to transfer control back to the main thread
        # (`has_paused`, either when it gets to a
        # `get_replica_context().merge_call` or when `fn` returns). In
        # either case the event starts cleared and is signaled by calling
        # set(). The receiving thread waits for the signal by calling
        # wait() and then immediately clearing the event using clear().
        self.should_run = threading.Event()
        self.has_paused = threading.Event()
        # These fields have to do with inheriting various contexts from the
        # parent thread:
        context.ensure_initialized()
        ctx = context.context()
        self.in_eager = ctx.executing_eagerly()
        self.record_thread_local_summary_state()
        self.record_thread_local_eager_context_state()
        self.context_device_policy = (
            pywrap_tfe.TFE_ContextGetDevicePlacementPolicy(ctx._context_handle)
        )  # pylint: disable=protected-access
        self.graph = ops.get_default_graph()
        with ops.init_scope():
            self._init_in_eager = context.executing_eagerly()
            self._init_graph = ops.get_default_graph()
        self._variable_creator_stack = self.graph._variable_creator_stack[:]  # pylint: disable=protected-access
        self._var_scope = variable_scope.get_variable_scope()
        # Adding a "/" at the end lets us re-enter this scope later.
        self._name_scope = self.graph.get_name_scope()
        if self._name_scope:
            self._name_scope += "/"
        if self.replica_id > 0:
            if not self._name_scope:
                self._name_scope = ""
            self._name_scope += "replica_%d/" % self.replica_id

        self._thread_local_callables = thread_local_callables
Example #44
  def setUp(self):
    context.ensure_initialized()
    super(PythonTensorConverterTest, self).setUp()
Example #45
def _setup_context():
    context._reset_context()
    test_util.set_logical_devices_to_at_least('CPU', 4)
    context.ensure_initialized()
Example #46
  def _ctx(self):
    # N.B. This is needed to support calling py_func with GPU tensors,
    # which must be transferred to CPU if used in any of the NumPy APIs.
    context.ensure_initialized()
    return context.context()._handle  # pylint: disable=protected-access
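The comment in Example #46 notes that tensors handed to NumPy APIs must live on the CPU. A brief sketch of that host transfer with public APIs, assuming a GPU is available; .numpy() always returns a host (CPU) copy of the data.

import tensorflow as tf

if tf.config.list_logical_devices('GPU'):
  with tf.device('/GPU:0'):
    t = tf.constant([1.0, 2.0, 3.0])
  print(t.device)          # ends with .../device:GPU:0
  host_array = t.numpy()   # copies the buffer back to host memory
  print(host_array.sum())  # 6.0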