def copy_helper(tensor):
  """Splits `tensor` 3 ways along axis 2, copies that sharding onto an
  unannotated copy of `tensor`, and checks both report a [1, 1, 3] tile shape.
  """
  tensor_src = array_ops.identity(tensor)
  tensor_src = xla_sharding.split(tensor, 2, 3)
  sharding_src = xla_sharding.get_tensor_sharding(tensor_src)
  shape_src = xla_sharding.get_sharding_tile_shape(sharding_src)
  self.assertEqual([1, 1, 3], shape_src)

  tensor_dest = array_ops.identity(tensor)
  self.assertIsNone(xla_sharding.get_tensor_sharding(tensor_dest))

  xla_sharding.copy_sharding(tensor_src, tensor_dest)
  sharding_dest = xla_sharding.get_tensor_sharding(tensor_dest)
  shape_dest = xla_sharding.get_sharding_tile_shape(sharding_dest)
  self.assertEqual([1, 1, 3], shape_dest)
  return tensor_dest
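
# Hedged usage sketch (illustrative, not from the original test): a standalone
# version of the sharding-copy pattern exercised by copy_helper above, without
# the test-case assertions. The import paths and the jit_compile wrapper are
# assumptions for illustration; the xla_sharding module path has varied across
# TensorFlow versions.
from tensorflow.python.compiler.xla.experimental import xla_sharding  # assumed path
from tensorflow.python.eager import def_function
from tensorflow.python.ops import array_ops


@def_function.function(jit_compile=True)
def _copy_sharding_sketch(tensor):
  # Annotate a copy of `tensor` with a 3-way split along axis 2, then copy
  # that sharding annotation onto an otherwise unannotated copy.
  tensor_src = xla_sharding.split(array_ops.identity(tensor), 2, 3)
  tensor_dest = array_ops.identity(tensor)
  tensor_dest = xla_sharding.copy_sharding(tensor_src, tensor_dest)
  return tensor_dest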
def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When initializing from `val` instead of a callable initializer, the shape
  # is expected to be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, variables.RefVariable):
    use_resource = False
  else:
    use_resource = None
  slot = variable_scope.get_variable(
      scope,
      initializer=val,
      trainable=False,
      use_resource=use_resource,
      shape=shape,
      dtype=dtype,
      validate_shape=validate_shape)
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable. Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weight". We want to get "Adam" as real_slot_name, so we
    # remove "'linear//weight' + '/'" and ":0".
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    # Support the slot's shape differing from the primary's shape.
    # Example: if the primary's shape is [10, 20, 30], a slot shape of
    # None, [], [10], [10, 20] or [10, 20, 30] is allowed:
    #   - slot's shape is None or [10, 20, 30]: set the slot's slice_info to
    #     the same as the primary's
    #   - slot's shape is []: don't set the slot's slice_info
    #   - slot's shape is [10] or [10, 20]: set the slot's slice_info
    #     according to its ndims
    n = slot.shape.ndims
    if n is None or n > 0:
      slot._set_save_slice_info(
          variables.Variable.SaveSliceInfo(
              slice_info.full_name + "/" + real_slot_name,
              slice_info.full_shape[:n],
              slice_info.var_offset[:n],
              slice_info.var_shape[:n]))
  # pylint: enable=protected-access

  # Copy XLA sharding attributes from the primary if the slot variable has the
  # same rank as the primary.
  def _has_same_rank(primary_shape, slot_shape):
    return (primary_shape.rank is not None and slot_shape.rank is not None and
            primary_shape.rank == slot_shape.rank)

  if copy_xla_sharding and _has_same_rank(primary.shape, slot.shape):
    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
  return slot
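
# Hedged usage sketch (illustrative, not from the original file): how an
# optimizer-style caller might build a zero-initialized "m" slot for `primary`
# in graph mode while asking _create_slot_var to propagate the primary's XLA
# sharding. The "m" slot name, the zeros initializer, and the direct call to
# this private helper are assumptions; real callers go through the public
# create_slot* wrappers (e.g. create_zeros_slot) in slot_creator.py. The
# module-level imports of the original file (variable_scope, etc.) are assumed.
from tensorflow.python.ops import init_ops  # assumed available in this repo


def _zeros_slot_sketch(primary):
  # Scope the slot under the primary's name so slot.name looks like
  # "<primary op name>/m/...", the convention the partitioned-variable branch
  # above relies on when it recovers real_slot_name.
  with variable_scope.variable_scope(None, primary.op.name + "/m"):
    return _create_slot_var(
        primary,
        init_ops.zeros_initializer(),  # callable initializer, so shape is kept
        "",  # empty name: the enclosing variable_scope carries the full name
        validate_shape=primary.shape.is_fully_defined(),
        shape=primary.shape,
        dtype=primary.dtype,
        copy_xla_sharding=True)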