예제 #1
0
    def setUp(self):
        """Builds a two-component ParallelDevice on the best available hardware.

        The device type is chosen in priority order TPU > GPU > CPU:

        * TPU: the TPU system is initialized; no logical split is configured
          here.
        * GPU: the first physical GPU is split into two logical devices, each
          configured with ``memory_limit=100``.
        * CPU: the first physical CPU is split into two logical devices.

        ``self.device`` then spans logical device 0 (given as a fully-qualified
        name) and logical device 1 (given in short ``"<TYPE>:1"`` form). The
        method asserts that each resulting component name contains the
        expected ``"<TYPE>:<index>"`` string.
        """
        super(_VirtualDeviceTestCase, self).setUp()
        ctx = context.context()
        if ctx.list_physical_devices("TPU"):
            self.device_type = "TPU"
            tpu_strategy_util.initialize_tpu_system()
        else:
            if ctx.list_physical_devices("GPU"):
                self.device_type = "GPU"
                # Both logical GPUs receive the same memory_limit.
                logical_configs = [
                    context.LogicalDeviceConfiguration(memory_limit=100)
                    for _ in range(2)
                ]
            else:
                self.device_type = "CPU"
                logical_configs = [
                    context.LogicalDeviceConfiguration() for _ in range(2)
                ]
            # Only the first physical device of the chosen type is split.
            first_physical = ctx.list_physical_devices(self.device_type)[0]
            ctx.set_logical_device_configuration(first_physical,
                                                 logical_configs)

        leading = "/job:localhost/device:{}:0".format(self.device_type)
        trailing = self.device_type + ":1"
        self.device = parallel_device.ParallelDevice(
            components=[leading, trailing])
        # Each component should name the logical device it was built from,
        # whether it was given fully-qualified or in short form.
        for position, suffix in enumerate((":0", ":1")):
            self.assertIn(self.device_type + suffix,
                          self.device.components[position])
예제 #2
0
 def test_one_replica_eager_control_flow(self):
     """Python control flow on a single-component ParallelDevice.

     The tensor [2, 3, 4] is packed as the only per-component value of a
     ParallelDevice with one component. A Python conditional then branches
     on whether any element equals 4. Because 4 is present, the true branch
     is taken, so the unpacked result must be [1].
     """
     single_replica = parallel_device.ParallelDevice(components=[
         "/job:localhost/device:{}:0".format(self.device_type),
     ])
     values = constant_op.constant([2, 3, 4])
     with single_replica:
         packed = single_replica.pack([values])
         contains_four = math_ops.reduce_any(
             math_ops.equal(packed, constant_op.constant(4)))
         # Eager Python branching on a tensor placed on the parallel device;
         # the chosen constant is created inside the device scope.
         result = constant_op.constant(1 if contains_four else 2)
     self.assertAllEqual([1], single_replica.unpack(result))
예제 #3
0
  def setUp(self):
    """Creates a two-component ParallelDevice over virtual CPUs.

    The first physical CPU is split into four logical CPUs. ``self.device``
    spans logical CPU 0 (given as a fully-qualified name) and logical CPU 1
    (given in short ``"CPU:1"`` form). The method asserts that each
    component name contains the expected ``"CPU:<index>"`` string.
    """
    super(_VirtualDeviceTestCase, self).setUp()
    ctx = context.context()
    first_cpu = ctx.list_physical_devices("CPU")[0]
    # Four virtual CPUs are configured; only the first two are used below.
    ctx.set_logical_device_configuration(
        first_cpu, [context.LogicalDeviceConfiguration() for _ in range(4)])

    self.device = parallel_device.ParallelDevice(
        components=["/job:localhost/device:CPU:0", "CPU:1"])
    for idx, expected in enumerate(("CPU:0", "CPU:1")):
      self.assertIn(expected, self.device.components[idx])
  def setUp(self):
    """Creates a ParallelDevice over two fully-qualified virtual CPUs.

    The first physical CPU is split into four logical CPUs. ``self.device``
    is built from the fully-qualified names of logical CPUs 0 and 1. No
    component-name assertions are made here.

    NOTE(review): SOURCE contains more than one ``setUp`` at this
    indentation. If they end up in the same class, the later definition
    silently replaces the earlier one; confirm which variant is intended.
    """
    super(_VirtualDeviceTestCase, self).setUp()
    ctx = context.context()
    physical_cpus = ctx.list_physical_devices("CPU")
    # Set 4 virtual CPUs on the first physical CPU.
    virtual_cpus = [context.LogicalDeviceConfiguration() for _ in range(4)]
    ctx.set_logical_device_configuration(physical_cpus[0], virtual_cpus)

    # TODO(allenl): Make CPU:0 and CPU:1 work (right now "CPU:1" soft-places
    # onto CPU:0, which seems wrong).
    fully_qualified = [
        "/job:localhost/replica:0/task:0/device:CPU:0",
        "/job:localhost/replica:0/task:0/device:CPU:1",
    ]
    self.device = parallel_device.ParallelDevice(fully_qualified)