Example #1
    def setUp(self):
        super(_VirtualDeviceTestCase, self).setUp()
        ctx = context.context()
        if ctx.list_physical_devices("TPU"):
            self.device_type = "TPU"
            tpu_strategy_util.initialize_tpu_system()
        elif ctx.list_physical_devices("GPU"):
            self.device_type = "GPU"
            gpus = ctx.list_physical_devices(self.device_type)
            ctx.set_logical_device_configuration(gpus[0], [
                context.LogicalDeviceConfiguration(memory_limit=100),
                context.LogicalDeviceConfiguration(memory_limit=100),
            ])
        else:
            self.device_type = "CPU"
            cpus = ctx.list_physical_devices("CPU")
            ctx.set_logical_device_configuration(cpus[0], [
                context.LogicalDeviceConfiguration(),
                context.LogicalDeviceConfiguration(),
            ])

        self.device = parallel_device.ParallelDevice(components=[
            "/job:localhost/device:{}:0".format(self.device_type),
            self.device_type + ":1"
        ])
        self.assertIn(self.device_type + ":0", self.device.components[0])
        self.assertIn(self.device_type + ":1", self.device.components[1])
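The snippets on this page call TensorFlow's internal config/context test helpers; the same virtual-device setup is exposed through the public tf.config API. A minimal standalone sketch (assuming TF 2.4 or newer and a single physical CPU) could look like this:

import tensorflow as tf

# Virtual devices must be configured before the runtime initializes,
# i.e. before any op or tensor touches the target device.
cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(cpus[0], [
    tf.config.LogicalDeviceConfiguration(),
    tf.config.LogicalDeviceConfiguration(),
])

print(tf.config.list_logical_devices("CPU"))  # two logical CPU devices

# Logical devices can be targeted explicitly by name.
with tf.device("CPU:1"):
    x = tf.constant([1.0, 2.0])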
Example #2
  def testMultipleGroups(self):
    context._reset_context()
    cpus = config.list_physical_devices('CPU')
    self.assertEqual(len(cpus), 1)
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    context.ensure_initialized()
    num_elements = 4

    @def_function.function
    def run_all_reduce(group_size, group_key):
      instance_key = group_key
      input_value = [group_key for i in range(num_elements)]
      collectives = []
      for device_idx in range(group_size):
        with ops.device('/CPU:{}'.format(device_idx)):
          input_tensor = constant_op.constant(input_value)
          collectives.append(collective_ops.all_reduce(
              input_tensor, group_size, group_key, instance_key, merge_op='Add',
              final_op='Id'))
      return collectives

    def run_and_assert(group_size, group_key):
      for reduced_tensor in run_all_reduce(group_size, group_key):
        self.assertAllEqual(
            [group_key * group_size for i in range(num_elements)],
            reduced_tensor.numpy())

    run_and_assert(group_size=2, group_key=1)
    run_and_assert(group_size=3, group_key=2)
Example #3
  def test_save_variable_devices(self, save_devices, meta_graph_only):
    context._reset_context()
    cpus = context.context().list_physical_devices("CPU")
    if len(cpus) == 1:
      context.context().set_logical_device_configuration(
          cpus[0], [
              context.LogicalDeviceConfiguration(),
              context.LogicalDeviceConfiguration()
          ])
    context.ensure_initialized()

    root = tracking.AutoTrackable()
    with ops.device("CPU:0"):
      root.v0 = variables.Variable(1., name="v0")
    with ops.device("CPU:1"):
      root.v1 = variables.Variable(1., name="v1")

    options = save_options.SaveOptions(
        experimental_variable_policy=save_devices)
    file_name = os.path.join(self.get_temp_dir(), "saved_model")
    if meta_graph_only:
      save.export_meta_graph(obj=root, filename=file_name, options=options)
    else:
      save.save(obj=root, export_dir=file_name, options=options)

    meta = None
    if meta_graph_only:
      meta = meta_graph.read_meta_graph_file(file_name)
    else:
      meta = loader_impl.parse_saved_model(file_name).meta_graphs[0]

    # Check devices in meta graph nodes.
    graph_def = meta.graph_def
    v0 = next((n for n in graph_def.node if n.name == "v0"), None)
    v1 = next((n for n in graph_def.node if n.name == "v1"), None)
    self.assertIsNotNone(v0)
    self.assertIsNotNone(v1)
    if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
    else:
      self.assertEmpty(v0.device)
      self.assertEmpty(v1.device)

    # Check devices in object graph nodes.
    object_graph_def = meta.object_graph_def
    v0 = next((n.variable
               for n in object_graph_def.nodes
               if n.HasField("variable") and n.variable.name == "v0"), None)
    v1 = next((n.variable
               for n in object_graph_def.nodes
               if n.HasField("variable") and n.variable.name == "v1"), None)
    self.assertIsNotNone(v0)
    self.assertIsNotNone(v1)
    if save_devices == save_options.VariablePolicy.SAVE_VARIABLE_DEVICES:
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
    else:
      self.assertEmpty(v0.device)
      self.assertEmpty(v1.device)
Example #4
def _mimic_two_cpus():
    cpus = config.list_physical_devices("CPU")

    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])
Example #5
  def testCollectiveGroupSizeMismatch(self):
    cpus = config.list_physical_devices('CPU')
    self.assertEqual(len(cpus), 1)
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    context.ensure_initialized()

    @def_function.function
    def run_all_reduce():
      group_key = 10
      instance_key = 20
      t0 = [1, 2, 3, 4]
      t1 = [5, 6, 7, 8]
      with ops.device('/CPU:0'):
        in0 = constant_op.constant(t0)
        c0 = collective_ops.all_reduce(
            in0, group_size=2, group_key=group_key, instance_key=instance_key,
            merge_op='Add', final_op='Id')
      with ops.device('/CPU:1'):
        in1 = constant_op.constant(t1)
        c1 = collective_ops.all_reduce(
            in1, group_size=3, group_key=group_key, instance_key=instance_key,
            merge_op='Add', final_op='Id')
      return c0, c1

    with self.assertRaisesRegex(errors.InternalError,
                                'but that group has size'):
      run_all_reduce()
Example #6
  def setUp(self):
    super(StatefulRandomOpsTest, self).setUp()
    physical_devices = config.list_physical_devices("CPU")
    config.set_logical_device_configuration(physical_devices[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
Example #7
def configure_virtual_cpus():
    cpus = config.list_physical_devices('CPU')
    # Set 2 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
Example #8
  def test_expand_distributed_variables(self, expand_strategy, policy):
    # 1. Create a context with both CPU:0 and CPU:1.
    context._reset_context()
    cpus = context.context().list_physical_devices("CPU")
    if len(cpus) == 1:
      context.context().set_logical_device_configuration(
          cpus[0], [
              context.LogicalDeviceConfiguration(),
              context.LogicalDeviceConfiguration()
          ])
    context.ensure_initialized()

    # 2. Create and save a model under a mirrored strategy.
    file_name = os.path.join(self.get_temp_dir(), "saved_model.pb")
    strategy = mirrored_strategy.MirroredStrategy(["CPU:0", "CPU:1"])
    strategy.extended._use_var_policy = policy
    with strategy.scope():
      root = tracking.AutoTrackable()
      root.v = variables.Variable([1., 1.], name="v")

      @def_function.function(input_signature=[])
      def f():
        root.v.assign([2., 2.])

      root.f = f

      save.export_meta_graph(
          obj=root,
          filename=file_name,
          options=save_options.SaveOptions(
              experimental_variable_policy=expand_strategy))

    # 3. Read the output file and test behavior.
    meta_graph_def = meta_graph.read_meta_graph_file(file_name)
    object_graph = meta_graph_def.object_graph_def
    graph_def = meta_graph_def.graph_def
    v = next((n.variable
              for n in object_graph.nodes
              if n.HasField("variable") and n.variable.name == "v"), None)
    saved_function = next((f for f in graph_def.library.function
                           if "inference_f_" in f.signature.name), None)
    self.assertIsNotNone(saved_function)
    if (expand_strategy ==
        save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES):
      # experimental_save_variable_devices should have been automatically set.
      self.assertIn("CPU:0", v.device)
      components = v.experimental_distributed_variable_components
      self.assertLen(components, 2)
      v0 = next((x for x in components if x.name == "v"), None)
      v1 = next((x for x in components if x.name == "v/replica_1"), None)
      self.assertIsNotNone(v0)
      self.assertIsNotNone(v1)
      self.assertIn("CPU:0", v0.device)
      self.assertIn("CPU:1", v1.device)
      self.assertLen(saved_function.signature.input_arg, 2)
    else:
      self.assertEmpty(v.device)
      self.assertEmpty(v.experimental_distributed_variable_components)
      self.assertLen(saved_function.signature.input_arg, 1)
Example #9
  def setUp(self):
    super(LayerCorrectnessTest, self).setUp()
    # Set two virtual CPUs to test MirroredStrategy with multiple devices
    cpus = config_module.list_physical_devices('CPU')
    config_module.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])
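The comment in this setUp mentions exercising MirroredStrategy across the two virtual CPUs; a minimal public-API sketch of that pattern (an illustrative assumption, not part of the original test) might be:

import tensorflow as tf

cpus = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(cpus[0], [
    tf.config.LogicalDeviceConfiguration(),
    tf.config.LogicalDeviceConfiguration(),
])

# Each logical CPU backs one replica of the mirrored strategy.
strategy = tf.distribute.MirroredStrategy(["CPU:0", "CPU:1"])
with strategy.scope():
    v = tf.Variable(1.0)

print(strategy.num_replicas_in_sync)  # 2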
Example #10
  def setUp(self):
    super(RpcOpsTest, self).setUp()
    cpus = config.list_physical_devices("CPU")
    # Set 2 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
Example #11
  def setUp(self):
    super(PackedDistributedVariableTest, self).setUp()
    cpus = config.list_physical_devices('CPU')
    # Set 2 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])
Example #12
  def testCollectiveGroupSizeMismatch(self):
    cpus = config.list_physical_devices('CPU')
    self.assertEqual(len(cpus), 1)
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    context.ensure_initialized()
Example #13
  def _setup_context(self, num_cpus=2):
    context._reset_context()
    cpus = config.list_physical_devices('CPU')
    self.assertEqual(len(cpus), 1)
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    context.ensure_initialized()
Example #14
  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    super(TestMultiGPUModel, self).__init__(methodName)
    gpu_devices = config.list_physical_devices('GPU')
    if len(gpu_devices) == 1:
      # A GPU is available, simulate 2 instead.
      config.set_logical_device_configuration(gpu_devices[0], [
          context.LogicalDeviceConfiguration(500),
          context.LogicalDeviceConfiguration(500)
      ])
Example #15
    def testAbortInstanceParamsResolution(self):
        cpus = config.list_physical_devices('CPU')
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        group_size = 2
        group_key = 100
        instance_key = 100
        in_tensor = constant_op.constant(1.)

        def collective_fn():
            for device in ['CPU:0', 'CPU:1']:
                with ops.device(device):
                    collective_ops.all_reduce(in_tensor,
                                              group_size,
                                              group_key,
                                              instance_key,
                                              'Add',
                                              'Id',
                                              communication_hint='ring')

        # First perform a normal all-reduce to complete the group resolution.
        def_function.function(collective_fn)()

        def abort_fn():
            time.sleep(2)
            context.context().abort_collective_ops(errors.UNAVAILABLE,
                                                   'peer down')

        t = threading.Thread(target=abort_fn)
        t.start()

        # Use a different instance key to trigger another instance resolution.
        instance_key = 101
        with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
            # This hangs on params resolution since we're only launching one
            # collective for a group size of 2.
            collective_ops.all_reduce(in_tensor, group_size, group_key,
                                      instance_key, 'Add', 'Id')

        # After abortion, subsequent collectives should fail immediately.
        with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
            collective_ops.all_reduce(in_tensor, group_size, group_key,
                                      instance_key, 'Add', 'Id')

        # Reset the context in order to reset the collective executor.
        context._reset_context()  # pylint: disable=protected-access
        t.join()

        # After reset non-NCCL collectives should work.
        cpus = config.list_physical_devices('CPU')
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        def_function.function(collective_fn)()
Example #16
def set_up_virtual_devices():
    global _virtual_devices_ready
    if _virtual_devices_ready:
        return
    physical_devices = config.list_physical_devices('CPU')
    config.set_logical_device_configuration(physical_devices[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    _virtual_devices_ready = True
Example #17
  def setUp(self):
    super(FunctionGradientsTest, self).setUp()
    cpus = config.list_physical_devices('CPU')
    # Set 4 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
Example #18
  def _ensure_context_initialized(self):
    gpus = config.list_physical_devices('GPU')
    if len(gpus) < 1:
      self.skipTest('Expected at least 1 GPU but found {} GPUs'.format(
          len(gpus)))
    config.set_logical_device_configuration(gpus[0], [
        context.LogicalDeviceConfiguration(1024),
        context.LogicalDeviceConfiguration(1024)
    ])
    context.ensure_initialized()
Example #19
def _setup_context():
    context._reset_context()
    cpus = config.list_physical_devices('CPU')
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    context.ensure_initialized()
Example #20
    def testExecutionAfterTimeoutV2(self):
        timeout = 1.5
        cpus = config.list_physical_devices('CPU')
        self.assertEqual(len(cpus), 1)
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        context.ensure_initialized()

        group_key = 20
        instance_key = 30
        input_data = constant_op.constant([1, 2, 3, 4])

        @def_function.function
        def run_all_reduce():
            for device in ['CPU:0', 'CPU:1']:
                with ops.device(device):
                    collective_ops.all_reduce(input_data,
                                              group_size=2,
                                              group_key=group_key,
                                              instance_key=instance_key,
                                              merge_op='Add',
                                              final_op='Id',
                                              timeout=timeout)

        # Run a normal all-reduce to complete param resolution.
        run_all_reduce()

        with self.assertRaisesRegex(
                errors.DeadlineExceededError,
                'Collective has timed out during execution'):
            with ops.device('CPU:0'):
                collective_ops.all_reduce(input_data,
                                          group_size=2,
                                          group_key=group_key,
                                          instance_key=instance_key,
                                          merge_op='Add',
                                          final_op='Id',
                                          timeout=timeout)

        # We launch the second device after the first device times out. This is to
        # simulate the situation when other workers are slow and the timeout is
        # short. It should error immediately.
        with self.assertRaisesRegex(
                errors.DeadlineExceededError,
                'Collective has timed out during execution'):
            with ops.device('CPU:1'):
                # No timeout.
                collective_ops.all_reduce(input_data,
                                          group_size=2,
                                          group_key=group_key,
                                          merge_op='Add',
                                          final_op='Id',
                                          instance_key=instance_key)
Example #21
  def setUp(self):
    super(SaverTest, self).setUp()
    cpus = config.list_physical_devices("CPU")
    # Set 3 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    self.local_options = checkpoint_options.CheckpointOptions(
        experimental_io_device=LOCALHOST)
Example #22
    def testAbortRing(self):
        cpus = config.list_physical_devices('CPU')
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        group_size = 2
        group_key = 100
        instance_key = 100
        in_tensor = constant_op.constant(1.)

        # First perform a normal collective to finish resolution.
        def collective_fn():
            for device in ['CPU:0', 'CPU:1']:
                with ops.device(device):
                    collective_ops.all_reduce(in_tensor,
                                              group_size,
                                              group_key,
                                              instance_key,
                                              'Add',
                                              'Id',
                                              communication_hint='ring')

        def_function.function(collective_fn)()

        # Launch a collective that hangs, and abort the collective executor after
        # the launch.
        def abort_fn():
            time.sleep(2)
            context.context().abort_collective_ops(errors.UNAVAILABLE,
                                                   'peer down')

        t = threading.Thread(target=abort_fn)
        t.start()

        with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
            collective_ops.all_reduce(in_tensor, group_size, group_key,
                                      instance_key, 'Add', 'Id')

        # After abortion, subsequent collectives should fail immediately.
        with self.assertRaisesRegex(errors.UnavailableError, 'peer down'):
            collective_ops.all_reduce(in_tensor, group_size, group_key,
                                      instance_key, 'Add', 'Id')

        # Reset the context in order to reset the collective executor.
        t.join()
        context._reset_context()  # pylint: disable=protected-access
        # After reset non-NCCL collectives should work.
        cpus = config.list_physical_devices('CPU')
        config.set_logical_device_configuration(cpus[0], [
            context.LogicalDeviceConfiguration(),
            context.LogicalDeviceConfiguration()
        ])
        def_function.function(collective_fn)()
 def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
     super(LocalReplicateTest, self).__init__(methodName)
     cpus = config.list_physical_devices("CPU")
     # Set 3 virtual CPUs
     config.set_logical_device_configuration(cpus[0], [
         context.LogicalDeviceConfiguration(),
         context.LogicalDeviceConfiguration(),
         context.LogicalDeviceConfiguration()
     ])
     self._device0 = "/device:CPU:0"
     self._device1 = "/device:CPU:1"
     self._device2 = "/device:CPU:2"
Example #24
  def __init__(self, methodName='runTest'):  # pylint: disable=invalid-name
    super(TestMultiGPUModel, self).__init__(methodName)
    gpu_devices = config.list_physical_devices('GPU')
    xla_gpu_devices = config.list_physical_devices('XLA_GPU')
    # NOTE: XLA devices don't support the set_logical_device_configuration
    # codepaths.
    if len(gpu_devices) == 1 and not xla_gpu_devices:
      # A GPU is available, simulate 2 instead.
      config.set_logical_device_configuration(gpu_devices[0], [
          context.LogicalDeviceConfiguration(500),
          context.LogicalDeviceConfiguration(500)
      ])
Example #25
def _mimic_two_cpus():
    try:
        cpus = config.list_physical_devices("CPU")
    except errors_impl.NotFoundError:
        # Testing device not available. Skip the test.
        return False

    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])
    return True
Example #26
  def setUp(self):
    super(HardDevicePlacementTest, self).setUp()
    context._reset_context()
    config.set_soft_device_placement(enabled=False)
    context.context().log_device_placement = True
    cpus = context.context().list_physical_devices('CPU')
    # Set 2 virtual CPUs
    context.context().set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])
    self.assertEqual(config.get_soft_device_placement(), False)
    self.assertEqual(context.context().soft_device_placement, False)
Example #27
  def setUp(self):
    super(_VirtualDeviceTestCase, self).setUp()
    cpus = context.context().list_physical_devices("CPU")
    # Set 4 virtual CPUs
    context.context().set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])

    self.device = parallel_device.ParallelDevice(
        components=["/job:localhost/device:CPU:0", "CPU:1"])
    self.assertIn("CPU:0", self.device.components[0])
    self.assertIn("CPU:1", self.device.components[1])
Example #28
    def testGpuInvalidConfig(self):
        gpus = config.list_physical_devices('GPU')
        self.assertNotEqual(len(gpus), 0)

        if len(gpus) > 1:
            # Assert if other GPUs were not configured
            config.set_memory_growth(gpus[0], True)
            with self.assertRaisesRegex(ValueError, 'cannot differ'):
                c = context.context().config

            # If we limit visibility to GPU 0, growth is fine
            config.set_visible_devices(gpus[0], 'GPU')
            c = context.context().config
            self.assertTrue(c.gpu_options.allow_growth)

            # Default setting for second GPU is False and works if we set visibility
            config.set_visible_devices(gpus[1], 'GPU')
            c = context.context().config
            self.assertFalse(c.gpu_options.allow_growth)

            # Growth now fails because all the GPUs are visible and not the same
            config.set_visible_devices(gpus, 'GPU')
            with self.assertRaisesRegex(ValueError, 'cannot differ'):
                c = context.context().config

        for gpu in gpus:
            config.set_memory_growth(gpu, True)

        c = context.context().config
        self.assertTrue(c.gpu_options.allow_growth)

        with self.assertRaisesRegex(ValueError, 'memory limit'):
            config.set_logical_device_configuration(gpus[-1], [
                context.LogicalDeviceConfiguration(),
                context.LogicalDeviceConfiguration()
            ])

        self.assertIsNone(config.get_logical_device_configuration(gpus[-1]))
        config.set_logical_device_configuration(gpus[-1], [
            context.LogicalDeviceConfiguration(memory_limit=10),
            context.LogicalDeviceConfiguration(memory_limit=10)
        ])

        c = context.context().config
        self.assertFalse(c.gpu_options.allow_growth)

        with self.assertRaisesRegex(ValueError, 'virtual devices'):
            config.set_memory_growth(gpus[-1], False)
Example #29
    def testKeepLogicalDevice(self):
        gpus = tf_config.list_physical_devices('GPU')
        if len(gpus) > 1:
            self.skipTest(
                'Skip logical device test on multi GPUs, since partial GPU '
                'virtualization is not permitted.')
        # Cannot change logical device after the context initialization.
        context._reset_context()  # pylint: disable=protected-access
        cluster_spec = multi_worker_test_base.create_cluster_spec(
            has_chief=False, num_workers=1)
        resolver = cluster_resolver_lib.SimpleClusterResolver(
            cluster_spec=multi_worker_util.normalize_cluster_spec(
                cluster_spec),
            task_type='worker',
            task_id=0)

        logical_gpus = len(gpus) * 2
        for i, device in enumerate(gpus):
            n = (i +
                 1) * logical_gpus // len(gpus) - i * logical_gpus // len(gpus)
            assert n > 0  # guaranteed if count >= len(devices)
            configs = []
            for ordinal in range(n):
                config = context.LogicalDeviceConfiguration(
                    memory_limit=64, experimental_device_ordinal=ordinal)
                configs.append(config)

            tf_config.set_logical_device_configuration(device, configs)

        collective_all_reduce_strategy.CollectiveAllReduceStrategy(
            cluster_resolver=resolver)
        # Since we create two logical GPUs out of the last GPU, there should be one
        # more logical GPUs than physical GPUs.
        self.assertLen(tf_config.list_logical_devices('GPU'), logical_gpus)
        context._reset_context()  # pylint: disable=protected-access
Example #30
def set_up_gpu_memory_limit(memory_limit_mb: int) -> None:
    gpus = framework_config.list_physical_devices("GPU")
    virtual_device_config = context.LogicalDeviceConfiguration(
        memory_limit=memory_limit_mb)
    for gpu in gpus:
        framework_config.set_logical_device_configuration(
            gpu, [virtual_device_config])
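On GPUs the same call also caps the memory of each logical device (memory_limit is in megabytes); a minimal sketch using the public tf.config API, assuming at least one visible GPU and TF 2.4 or newer, could be:

import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")
if gpus:
    # Split the first physical GPU into two logical devices of ~1 GB each.
    # This must run before the GPU has been initialized.
    tf.config.set_logical_device_configuration(gpus[0], [
        tf.config.LogicalDeviceConfiguration(memory_limit=1024),
        tf.config.LogicalDeviceConfiguration(memory_limit=1024),
    ])
    print(tf.config.list_logical_devices("GPU"))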