示例#1
0
def open(device_id=None):
    """Initiate and return a NPU device handle.

    Args:
        device_id: Integer id of the NPU device to open. When None, falls
            back to the ``ASCEND_DEVICE_ID`` environment variable
            (defaulting to device 0).

    Returns:
        The ``NpuDeviceHandle`` for ``device_id`` — either a previously
        cached instance or a freshly opened one.

    Raises:
        RuntimeError: If an NPU instance already exists on a *different*
            device, if required distribution env vars (``RANK_TABLE_FILE``,
            ``RANK_ID``) are missing when ``RANK_SIZE`` > 1, or if the
            backend fails to open the device.
    """
    if device_id is None:
        device_id = int(os.getenv("ASCEND_DEVICE_ID", '0'))

    # All open/lookup logic runs under the module lock so concurrent callers
    # cannot race on the global context swap or the instance cache.
    with _npu_ctx_lock:
        # Make sure the process-wide eager context is our subclass that
        # supports a default NPU device; swapping the context invalidates
        # any handles cached against the old one.
        if not isinstance(context.context(), _ContextWithDefaultDevice):
            ctx = _ContextWithDefaultDevice()
            ctx.ensure_initialized()
            context._set_context(ctx)
            _npu_device_instances.clear(
            )  # Global context has changed since last init npu

        # Fast path: reuse the already-opened handle for this device.
        if device_id in _npu_device_instances.keys():
            logging.info('Npu instance on device %s already created',
                         str(device_id))
            return _npu_device_instances.get(device_id)

        # Only one NPU device may be open per process; a non-empty cache
        # here means a *different* device id was requested.
        if len(_npu_device_instances):
            raise RuntimeError(
                'Failed create npu instance on device {} as existed instance on {}'
                ''.format(device_id, list(_npu_device_instances.keys())))

        global_kw_options = global_options().as_dict()
        workers_num = int(os.getenv('RANK_SIZE', '1'))
        if workers_num > 1:
            # Distributed mode: both the rank table and this worker's rank
            # id must be provided via the environment.
            env_rank_table = os.getenv("RANK_TABLE_FILE")
            env_worker_id = os.getenv('RANK_ID')
            if not env_rank_table:
                raise RuntimeError(
                    'You must specify a rank table file by set env RANK_TABLE_FILE in distribution mode'
                )

            if not env_worker_id:
                raise RuntimeError(
                    'You must specify rank id by set env RANK_ID in distribution mode'
                )

            global_kw_options['_distribute.rank_table'] = env_rank_table
            global_kw_options['_distribute.rank_id'] = env_worker_id

        # The backend fills device_options in place and returns an error
        # string on failure (empty/None on success).
        device_options = {}
        error_message = _npu_device_backends.Open(context.context()._handle,
                                                  NPU, device_id,
                                                  global_kw_options,
                                                  device_options)
        if error_message:
            raise RuntimeError("Failed open npu device %s : %s" %
                               (str(device_id), error_message))

        if workers_num > 1:
            # Imported lazily: hccl is only available/needed in
            # distributed runs.
            from hccl.manage.api import get_rank_id
            worker_id = get_rank_id()
        else:
            worker_id = 0

        _npu_device_instances[device_id] = NpuDeviceHandle(
            context.context(), device_id, device_options, workers_num,
            worker_id)
        return _npu_device_instances[device_id]
示例#2
0
 def testSilentCopy(self):
     """Silent device policy permits implicit cross-device tensor copies."""
     # Run against a throwaway context; restore the original afterwards.
     # pylint: disable=protected-access
     saved_ctx = context.context()
     context._set_context(context.Context())
     try:
         config.set_device_policy('silent')
         host_value = constant_op.constant(1.0)
         device_value = host_value.gpu()
         self.assertAllEqual(host_value + device_value, 2.0)
     finally:
         context._set_context(saved_ctx)
示例#3
0
    def testCrossContextTensorCache(self):
        """A tensor created under one context remains usable after a swap."""
        prior_ctx = context.context()
        cached_tensor = constant_op.constant(9.5)
        context._set_context(context.Context())

        try:
            fresh_tensor = constant_op.constant(9.5)
            self.assertEqual(fresh_tensor.numpy(), 9.5)
        finally:
            context._set_context(prior_ctx)

        # The tensor from the original context must still resolve correctly.
        self.assertEqual(cached_tensor.numpy(), 9.5)
示例#4
0
 def testSoftPlacement(self):
     """Soft device placement should land the add op on the GPU."""
     # Use a temporary context so global policy changes do not leak.
     # pylint: disable=protected-access
     previous_ctx = context.context()
     context._set_context(context.Context())
     try:
         config.set_device_policy('silent')
         config.set_soft_device_placement(True)
         operand = constant_op.constant(1.0)
         total = operand + operand
         self.assertEqual(total.device,
                          '/job:localhost/replica:0/task:0/device:GPU:0')
     finally:
         context._set_context(previous_ctx)
示例#5
0
 def testSoftPlacement(self):
     """Soft device placement should land the add op on the GPU."""
     # Use a temporary context so global policy changes do not leak.
     # pylint: disable=protected-access
     previous_ctx = context.context()
     context._set_context(context.Context())
     try:
         config.set_device_policy('silent')
         config.set_soft_device_placement(True)
         # Use 1.1 (not 1.0) to dodge the TensorHandle cache.
         # TODO(b/169790439): include Context to the TensorHandle cache.
         operand = constant_op.constant(1.1)
         total = operand + operand
         self.assertEqual(total.device,
                          '/job:localhost/replica:0/task:0/device:GPU:0')
     finally:
         context._set_context(previous_ctx)
示例#6
0
  def testContextIsDestroyedAfterTensors(self):
    """Tensors keep their creating context alive until all are deleted."""
    # Create a new context, keeping only a weak reference so we can observe
    # when the underlying object is actually destroyed.
    new_context = context.Context()
    weak_c = weakref.ref(new_context)
    new_context.ensure_initialized()

    # Create tensors with the new context as default. try/finally replaces
    # the original duplicated restore-on-success plus bare `except:` —
    # it guarantees the original context is restored on every path.
    original_context = context.context()
    context._set_context(new_context)
    try:
      # Use a 2D tensor so that it is not cached.
      tensor1 = constant_op.constant([[3.]])
      # Produce a tensor as an operation output. This uses a different code
      # path from tensors created from Python.
      tensor2 = tensor1 * tensor1
    finally:
      context._set_context(original_context)

    # Deleting our context reference should not delete the underlying object.
    del new_context
    self.assertIsNot(weak_c(), None)

    # Deleting the first tensor should not delete the context since there is
    # another tensor.
    del tensor1
    self.assertIsNot(weak_c(), None)

    # Deleting the last tensor should result in deleting its context.
    del tensor2
    self.assertIs(weak_c(), None)