def _experimental_assign_to_logical_device(self, tensor, logical_device_id):
    """See `DistributionStrategy.experimental_assign_to_logical_device`.

    Args:
      tensor: The tensor to tag with a device assignment.
      logical_device_id: Index of the logical device within a replica. Must
        satisfy `0 <= logical_device_id < self._tpu_devices.shape[1]`.

    Returns:
      The result of `xla_sharding.assign_device(tensor, logical_device_id)`.

    Raises:
      ValueError: If `logical_device_id` is negative or is not lower than the
        number of logical devices per replica.
    """
    # The second dimension of `_tpu_devices` is used as the number of logical
    # devices available in each replica.
    num_logical_devices_per_replica = self._tpu_devices.shape[1]
    if not 0 <= logical_device_id < num_logical_devices_per_replica:
        # The message names the real argument (`logical_device_id`) and covers
        # both bounds that the check above enforces.
        raise ValueError(
            "`logical_device_id` to assign must be non-negative and lower "
            "than total number of logical devices per replica. Received "
            "logical device id {} but there are only total of {} "
            "logical devices in replica.".format(
                logical_device_id, num_logical_devices_per_replica))
    return xla_sharding.assign_device(tensor, logical_device_id)
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
    """Attaches the matching XLA sharding attribute to a dequeued tensor.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integers describing how the tensor is partitioned, or
        None.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
    # No partition spec: mark the tensor as replicated.
    if dims is None:
        return xla_sharding.replicate(tensor)

    # A single partition in total: pin the tensor to device 0.
    num_partitions = np.prod(dims)
    if num_partitions == 1:
        return xla_sharding.assign_device(tensor, 0)

    # Otherwise tile, numbering the devices 0..N-1 laid out in the `dims` grid.
    device_grid = np.arange(num_partitions).reshape(dims)
    return xla_sharding.tile(tensor=tensor, tile_assignment=device_grid)
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
    """Tags a dequeued tensor with the XLA sharding that `dims` calls for.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integers describing how the tensor is partitioned, or
        None.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
    if dims is None:
        # Nothing to partition: replicate.
        tagged = xla_sharding.replicate(tensor)
    elif np.prod(dims) == 1:
        # The partition counts multiply to one: assign to device 0.
        tagged = xla_sharding.assign_device(tensor, 0)
    else:
        # One device index per tile, arranged in the shape given by `dims`.
        assignment = np.arange(np.prod(dims)).reshape(dims)
        tagged = xla_sharding.tile(tensor=tensor, tile_assignment=assignment)
    return tagged
def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
    """Attaches the matching XLA sharding attribute to a dequeued tensor.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integers describing how the tensor is partitioned, or
        None.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
    # No partition spec: mark the tensor as replicated.
    if dims is None:
        return xla_sharding.replicate(tensor)

    # A single partition in total: pin the tensor to device 0.
    if np.prod(dims) == 1:
        return xla_sharding.assign_device(tensor, 0)

    # Per-tile extent: the tensor's shape floor-divided element-wise by `dims`.
    per_tile_dims = np.array(tensor.shape.as_list()) // dims
    # Device indices 0..N-1 arranged in the `dims` grid.
    device_grid = np.arange(np.prod(dims)).reshape(dims)
    per_tile_shape = xla_shape.CreateShapeFromDtypeAndTuple(
        dtype=np.dtype(tensor.dtype.as_numpy_dtype),
        shape_tuple=per_tile_dims)
    return xla_sharding.tile(
        tensor=tensor,
        tile_shape=per_tile_shape,
        tile_assignment=device_grid)
def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
    """Tags a dequeued tensor with the XLA sharding that `dims` calls for.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integers describing how the tensor is partitioned, or
        None.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
    if dims is None:
        # Nothing to partition: replicate.
        tagged = xla_sharding.replicate(tensor)
    elif np.prod(dims) == 1:
        # The partition counts multiply to one: assign to device 0.
        tagged = xla_sharding.assign_device(tensor, 0)
    else:
        # Shape of one tile: the full shape floor-divided element-wise by
        # `dims`.
        shard_dims = np.array(tensor.shape.as_list()) // dims
        # One device index per tile, arranged in the shape given by `dims`.
        assignment = np.arange(np.prod(dims)).reshape(dims)
        shard_shape = xla_shape.CreateShapeFromDtypeAndTuple(
            dtype=np.dtype(tensor.dtype.as_numpy_dtype),
            shape_tuple=shard_dims)
        tagged = xla_sharding.tile(
            tensor=tensor,
            tile_shape=shard_shape,
            tile_assignment=assignment)
    return tagged