def testStackedVarWrapperWithManualSharding(self):
  # Exercises assign / assign_add / assign_sub through the manual-sharding
  # wrapper of a stacked variable and checks the shape seen via the wrapper.
  graph = tf.Graph()
  with graph.as_default():
    stacked_var = tf.get_variable('v2', shape=[8, 16], dtype=tf.float32)
    # Annotate dimension 0 as split over 8 devices, attaching the sharding
    # directly (no separate sharding op).
    xla_sharding.split(stacked_var, 0, num_devices=8, use_sharding_op=False)
    sharded_view = var_tmp_wrappers.StackedVarWrapperWithManualSharding(
        stacked_var)
    unit_values = tf.ones_like(sharded_view)
    for method_name in ('assign', 'assign_add', 'assign_sub'):
      getattr(sharded_view, method_name)(unit_values)
    # `ones_like` of the wrapper is expected to have shape [16], not [8, 16].
    self.assertEqual(unit_values.shape, [16])
def testCreateSlotWithCustomSplitXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  # The slot is created with copy_xla_sharding=True and is then given a custom
  # split sharding; the assertions check that the slot ends up with the custom
  # sharding rather than the one carried by the primary variable.
  with ops.Graph().as_default(), self.cached_session():
    primary = variables.Variable([1.0, 2.5, 10.0, 15.1], name="var")
    primary = xla_sharding.mesh_split(
        primary, np.array([0, 1]), [0], use_sharding_op=False)
    with ops.control_dependencies(None):
      slot_var = slot_creator.create_zeros_slot(
          primary, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
      slot_var = xla_sharding.split(
          slot_var, split_dimension=0, num_devices=4, use_sharding_op=False)

    primary_sharding = xla_sharding.get_tensor_sharding(primary)
    slot_sharding = xla_sharding.get_tensor_sharding(slot_var)
    self.assertNotEqual(primary_sharding, slot_sharding)

    # Decode the serialized sharding and compare it with a 4-way tiling over
    # devices 0..3.
    parsed_sharding = xla_data_pb2.OpSharding()
    parsed_sharding.ParseFromString(slot_sharding)
    expected_sharding = xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.OTHER,
        tile_assignment_dimensions=[4],
        tile_assignment_devices=range(4))
    self.assertEqual(parsed_sharding, expected_sharding)
def split_helper(tensor):
  """Splits `tensor` along dimension 2 into 3 parts and verifies the result.

  Args:
    tensor: Tensor that must not carry a sharding annotation yet.

  Returns:
    The tensor returned by `xla_sharding.split`.
  """
  # The input must start out without any sharding.
  self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))

  sharded = xla_sharding.split(tensor, 2, 3)
  self.assertIsInstance(sharded, ops.Tensor)

  # The resulting tile shape is expected to be [1, 1, 3].
  tile_shape = xla_sharding.get_sharding_tile_shape(
      xla_sharding.get_tensor_sharding(sharded))
  self.assertEqual([1, 1, 3], tile_shape)
  return sharded
def _sharding(self, x):
  """Annotates `x` with an XLA split sharding when partitioning is enabled.

  Args:
    x: Tensor to annotate; only rank-2 and rank-3 tensors are split.

  Returns:
    `x` unchanged when `self._xla_num_partitions` is falsy or the rank is
    neither 2 nor 3; otherwise the result of `xla_sharding.split`.
  """
  num_partitions = self._xla_num_partitions
  if not num_partitions:
    return x

  # The current partitions are hard coded for Transformer/BERT like
  # models. TODO: make it more general for other models.
  if len(x.get_shape()) == 3:
    x = xla_sharding.split(x, 1, num_partitions, use_sharding_op=True)
  if len(x.get_shape()) == 2:
    dim0, dim1 = x.get_shape().as_list()
    # Split along the larger dimension; ties go to dimension 0.
    split_dim = 1 if dim0 < dim1 else 0
    x = xla_sharding.split(x, split_dim, num_partitions, use_sharding_op=True)
  return x
def copy_helper(tensor):
  """Checks that `copy_sharding` transfers a split sharding between tensors.

  Args:
    tensor: Input tensor used to derive both the sharded source and the
      initially unsharded destination.

  Returns:
    An identity copy of `tensor` that has received the source's sharding.
  """
  # Shard the identity copy, not `tensor` itself: the annotation then lands on
  # the copy, the identity result is actually used, and the input is left
  # untouched for the "destination starts unsharded" check below.
  tensor_src = array_ops.identity(tensor)
  tensor_src = xla_sharding.split(tensor_src, 2, 3)
  sharding_src = xla_sharding.get_tensor_sharding(tensor_src)
  shape_src = xla_sharding.get_sharding_tile_shape(sharding_src)
  self.assertEqual([1, 1, 3], shape_src)

  # A fresh identity copy must not carry any sharding before the copy.
  tensor_dest = array_ops.identity(tensor)
  self.assertIsNone(xla_sharding.get_tensor_sharding(tensor_dest))

  # After copying, the destination must report the same tile shape.
  xla_sharding.copy_sharding(tensor_src, tensor_dest)
  sharding_dest = xla_sharding.get_tensor_sharding(tensor_dest)
  shape_dest = xla_sharding.get_sharding_tile_shape(sharding_dest)
  self.assertEqual([1, 1, 3], shape_dest)
  return tensor_dest
def Split(x,
          split_dimension,
          num_devices,
          use_sharding_op=True,
          input_shape=None):
  """Conditionally applies `xla_sharding.split` to `x`.

  The annotation is applied only when `py_utils.use_tpu()` is true and
  `num_devices` is greater than 1; in every other case `x` is returned as is.

  Args:
    x: Tensor to annotate.
    split_dimension: Forwarded to `xla_sharding.split`.
    num_devices: Forwarded to `xla_sharding.split`. `None` or a value not
      greater than 1 disables the annotation.
    use_sharding_op: If true, adds a sharding op to set the sharding
      (`tensor = gen_xla_ops.xla_sharding(tensor)`); see
      http://cs/search/?q=XlaSharding+file:xla_ops.cc. Guidance quoted from
      hyouklee@: with `use_sharding_op=False` the sharding attribute is added
      to the op itself, so the information could be lost by TF graph
      transformations, and directly attaching the annotation caused some
      compilation failures in the past (incompatible shardings); the plan is
      to make `use_sharding_op` the default. The only case to set it to False
      today is when annotating weights, because weight annotation does some
      special handling that may need changes if a separate sharding op is
      added.
    input_shape: The shape of the original tensor.

  Returns:
    Tensor conditionally annotated with sharding.
  """
  should_shard = (
      py_utils.use_tpu() and num_devices is not None and num_devices > 1)
  if should_shard:
    return xla_sharding.split(
        x,
        split_dimension,
        num_devices,
        input_shape=input_shape,
        use_sharding_op=use_sharding_op,
    )
  return x
def eval_step_fn(params, model):
  """Build `step_fn` for eval.

  Args:
    params: Hyper-parameter object; read by attribute, and checked with `in`
      for the optional `eval_image_size` entry.
    model: Callable invoked as `model(images, training=False)`.

  Returns:
    A `(logits, labels, mask)` tuple, with `logits` cast to `tf.float32`.
  """
  image_dtype = tf.bfloat16 if params.use_bfloat16 else tf.float32
  dtypes = [image_dtype, tf.float32, tf.float32]
  batch_size = params.eval_batch_size // params.num_replicas
  # Fall back to the training image size when no eval size is configured.
  if 'eval_image_size' in params:
    image_size = params.eval_image_size
  else:
    image_size = params.image_size
  shapes = [
      [batch_size, image_size, image_size, 3],
      [batch_size, params.num_classes],
      [batch_size],
  ]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    # Partitioned infeed: only the images are partitioned (along dimension 2).
    infeed_queue = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=3,
        host_id=0,
        input_partition_dims=[
            [1, 1, params.num_cores_per_replica, 1],
            [1, 1],
            [1],
        ],
        device_assignment=params.device_assignment)
    infeed_queue.set_tuple_types(dtypes)
    infeed_queue.set_tuple_shapes(shapes)
    images, labels, mask = infeed_queue.generate_dequeue_op()
    images = xla_sharding.split(images, 2, params.num_cores_per_replica)
  else:
    with tf.device(tf.tpu.core(0)):
      images, labels, mask = tf.raw_ops.InfeedDequeueTuple(
          dtypes=dtypes, shapes=shapes)

  if len(labels.shape) > 1:
    # `labels` is one_hot. turn it to `int.32`
    labels = tf.expand_dims(
        tf.argmax(labels, axis=-1, output_type=tf.int32), axis=-1)
  # Called only for its side effect; the returned step tensor is not needed.
  tf.train.get_or_create_global_step()

  with tf.variable_scope(MODEL_SCOPE):
    logits = tf.cast(model(images, training=False), tf.float32)

  return logits, labels, mask
def split_helper(tensor):
  """Returns `tensor` split along dimension 0 across 8 devices."""
  return xla_sharding.split(tensor, 0, 8)
def step_fn(self, params, model):
  """Builds one teacher/student training step and returns its outfeed op.

  Separate implementation. The step dequeues labeled images, labels and two
  sets of unlabeled images ("ori" and "aug") from infeed, runs the teacher on
  all images, updates the student on the augmented unlabeled images,
  re-evaluates the student on the labeled images, and finally updates the
  teacher with a loss that includes a term weighted by the change in the
  student's labeled loss.

  Args:
    params: Hyper-parameter object read by attribute (for example
      `train_batch_size`, `num_replicas`, `uda_data`, `image_size`,
      `num_classes`, `grad_bound`, `mpl_student_lr`, `mpl_teacher_lr`,
      `uda_weight`, `uda_steps`).
    model: Callable invoked as `model(images, training=True)`; its `name`
      attribute is used to build the EMA name scope.

  Returns:
    The result of a `tf.cond` that enqueues the logged tensors to outfeed when
    `common_utils.should_log(params)` is true and runs `tf.no_op` otherwise.

  Side effects:
    Sets `self.step_info` to a dict mapping each log name to
    `[tf.float32, [1]]`.
  """
  train_batch_size = params.train_batch_size
  num_replicas = params.num_replicas
  uda_data = params.uda_data
  # Per-replica batch size for the labeled data.
  batch_size = train_batch_size // num_replicas

  # Infeed tuple layout: labeled images, labels, original unlabeled images,
  # augmented unlabeled images. The unlabeled batches are `uda_data` times
  # larger than the labeled batch.
  dtypes = [
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32]
  shapes = [
      [batch_size, params.image_size, params.image_size, 3],
      [batch_size, params.num_classes],
      [batch_size*params.uda_data, params.image_size, params.image_size, 3],
      [batch_size*params.uda_data, params.image_size, params.image_size, 3]]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    # Partitioned infeed: the three image tensors are partitioned along
    # dimension 2 across the cores of a replica; the labels are not.
    q = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=4,
        host_id=0,
        input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                              [1, 1],
                              [1, 1, params.num_cores_per_replica, 1],
                              [1, 1, params.num_cores_per_replica, 1],],
        device_assignment=params.device_assignment)
    q.set_tuple_types(dtypes)
    q.set_tuple_shapes(shapes)
    l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
    l_images = xla_sharding.split(l_images, 2,
                                  params.num_cores_per_replica)
    u_images_ori = xla_sharding.split(u_images_ori, 2,
                                      params.num_cores_per_replica)
    u_images_aug = xla_sharding.split(u_images_aug, 2,
                                      params.num_cores_per_replica)
  else:
    # Unpartitioned infeed, dequeued on core 0.
    with tf.device(tf.tpu.core(0)):
      (l_images, l_labels, u_images_ori,
       u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                     shapes=shapes)
  global_step = tf.train.get_or_create_global_step()
  # From here on `num_replicas` is a float32 tensor (used to scale the logs).
  num_replicas = tf.cast(params.num_replicas, tf.float32)

  all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)

  # all calls to teacher
  # NOTE(review): the returned dicts are indexed below with keys 'u_aug',
  # 'l' and 'u' -- confirm against `UDA.build_uda_cross_entropy`.
  with tf.variable_scope('teacher', reuse=tf.AUTO_REUSE):
    logits, labels, masks, cross_entropy = UDA.build_uda_cross_entropy(
        params, model, all_images, l_labels)

  # 1st call to student
  # Augmented-unlabeled and labeled images go through the student in a single
  # forward pass; the logits are then split back into the two groups.
  with tf.variable_scope(MODEL_SCOPE):
    u_aug_and_l_images = tf.concat([u_images_aug, l_images], axis=0)
    logits['s_on_u_aug_and_l'] = model(u_aug_and_l_images, training=True)
    logits['s_on_u'], logits['s_on_l_old'] = tf.split(
        logits['s_on_u_aug_and_l'],
        [u_images_aug.shape[0].value, l_images.shape[0].value], axis=0)

  # for backprop
  # The student is trained against the teacher's softmax on the augmented
  # unlabeled images; stop_gradient keeps this loss from reaching the teacher.
  cross_entropy['s_on_u'] = tf.losses.softmax_cross_entropy(
      onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], -1)),
      logits=logits['s_on_u'],
      label_smoothing=params.label_smoothing,
      reduction=tf.losses.Reduction.NONE)
  cross_entropy['s_on_u'] = tf.reduce_sum(cross_entropy['s_on_u']) / float(
      train_batch_size*uda_data)

  # for Taylor
  # Student loss on labeled data BEFORE the student update, summed over
  # replicas and stored in `shadow` so it can be compared with the loss after
  # the update.
  cross_entropy['s_on_l_old'] = tf.losses.softmax_cross_entropy(
      onehot_labels=labels['l'],
      logits=logits['s_on_l_old'],
      reduction=tf.losses.Reduction.SUM)
  cross_entropy['s_on_l_old'] = tf.tpu.cross_replica_sum(
      cross_entropy['s_on_l_old']) / float(train_batch_size)
  shadow = tf.get_variable(
      name='cross_entropy_old', shape=[], trainable=False, dtype=tf.float32)
  shadow_update = tf.assign(shadow, cross_entropy['s_on_l_old'])

  # Per-role bookkeeping: weights, gradients, gradient norms, learning rates
  # and optimizers, keyed by 's' / 's_on_u' (student) and 't' (teacher).
  w_s = {}
  g_s = {}
  g_n = {}
  lr = {}
  optim = {}
  w_s['s'] = [w for w in tf.trainable_variables()
              if w.name.lower().startswith(MODEL_SCOPE)]
  g_s['s_on_u'] = tf.gradients(cross_entropy['s_on_u'], w_s['s'])
  # g_s['s_on_u'] = [tf.tpu.cross_replica_sum(g) for g in g_s['s_on_u']]

  lr['s'] = common_utils.get_learning_rate(
      params,
      initial_lr=params.mpl_student_lr,
      num_warmup_steps=params.mpl_student_lr_warmup_steps,
      num_wait_steps=params.mpl_student_lr_wait_steps)
  lr['s'], optim['s'] = common_utils.get_optimizer(
      params, learning_rate=lr['s'])
  # Slots are created explicitly here, before the control-dependency block in
  # which the gradients are applied.
  optim['s']._create_slots(w_s['s'])  # pylint: disable=protected-access
  # Only update ops whose name starts with 'train/<MODEL_SCOPE>/' are kept
  # (presumably the student's -- TODO confirm).
  update_ops = [op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                if op.name.startswith(f'train/{MODEL_SCOPE}/')]

  # The student update is built after the update ops and after `shadow` has
  # been assigned the pre-update loss.
  with tf.control_dependencies(update_ops + [shadow_update]):
    g_s['s_on_u'] = common_utils.add_weight_decay(
        params, w_s['s'], g_s['s_on_u'])
    g_s['s_on_u'], g_n['s_on_u'] = tf.clip_by_global_norm(
        g_s['s_on_u'], params.grad_bound)
    train_op = optim['s'].apply_gradients(zip(g_s['s_on_u'], w_s['s']))

    with tf.control_dependencies([train_op]):
      ema_train_op = common_utils.setup_ema(
          params, name_scope=f'{MODEL_SCOPE}/{model.name}')

  # 2nd call to student
  # Built under a dependency on the EMA op, which itself depends on the
  # student's `train_op`, so this forward pass follows the student update.
  with tf.control_dependencies([ema_train_op]):
    with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
      logits['s_on_l_new'] = model(l_images, training=True)

  cross_entropy['s_on_l_new'] = tf.losses.softmax_cross_entropy(
      onehot_labels=labels['l'],
      logits=logits['s_on_l_new'],
      reduction=tf.losses.Reduction.SUM)
  cross_entropy['s_on_l_new'] = tf.tpu.cross_replica_sum(
      cross_entropy['s_on_l_new']) / float(train_batch_size)

  # Change in the student's labeled loss caused by the update above.
  dot_product = cross_entropy['s_on_l_new'] - shadow
  # dot_product = tf.clip_by_value(
  #     dot_product,
  #     clip_value_min=-params.mpl_dot_product_bound,
  #     clip_value_max=params.mpl_dot_product_bound)
  # Moving average of `dot_product` with rate 0.01:
  #   m <- m - 0.01 * (m - dot_product).
  moving_dot_product = tf.get_variable(
      'moving_dot_product', shape=[], trainable=False, dtype=tf.float32)
  moving_dot_product_update = tf.assign_sub(
      moving_dot_product, 0.01 * (moving_dot_product - dot_product))
  # Subtract the moving average after it is updated, and treat the result as
  # a constant in the teacher's gradient.
  with tf.control_dependencies([moving_dot_product_update]):
    dot_product = dot_product - moving_dot_product
    dot_product = tf.stop_gradient(dot_product)
  # Cross entropy of the teacher's 'u_aug' logits against its own
  # (gradient-stopped) softmax; scaled by `dot_product` in the teacher loss.
  cross_entropy['mpl'] = tf.losses.softmax_cross_entropy(
      onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], axis=-1)),
      logits=logits['u_aug'],
      reduction=tf.losses.Reduction.NONE)
  cross_entropy['mpl'] = tf.reduce_sum(cross_entropy['mpl']) / float(
      train_batch_size*uda_data)

  # teacher train op
  # `uda_weight` ramps linearly from 0 to `params.uda_weight` over the first
  # `params.uda_steps` global steps and stays constant afterwards.
  uda_weight = params.uda_weight * tf.minimum(
      1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
  teacher_loss = (cross_entropy['u'] * uda_weight +
                  cross_entropy['l'] +
                  cross_entropy['mpl'] * dot_product)
  w_s['t'] = [w for w in tf.trainable_variables() if 'teacher' in w.name]
  g_s['t'] = tf.gradients(teacher_loss, w_s['t'])
  g_s['t'] = common_utils.add_weight_decay(params, w_s['t'], g_s['t'])
  g_s['t'], g_n['t'] = tf.clip_by_global_norm(g_s['t'], params.grad_bound)
  lr['t'] = common_utils.get_learning_rate(
      params,
      initial_lr=params.mpl_teacher_lr,
      num_warmup_steps=params.mpl_teacher_lr_warmup_steps)
  lr['t'], optim['t'] = common_utils.get_optimizer(params,
                                                   learning_rate=lr['t'])

  # Only the teacher's apply_gradients is given `global_step`.
  teacher_train_op = optim['t'].apply_gradients(zip(g_s['t'], w_s['t']),
                                                global_step=global_step)

  # Everything logged is built under a dependency on the teacher update.
  with tf.control_dependencies([teacher_train_op]):
    logs = collections.OrderedDict()
    logs['global_step'] = tf.cast(global_step, tf.float32)

    # NOTE(review): several entries are divided by `num_replicas`, presumably
    # because the consumer sums outfeed values across replicas -- confirm.
    logs['cross_entropy/student_on_u'] = cross_entropy['s_on_u']
    logs['cross_entropy/student_on_l'] = (cross_entropy['s_on_l_new'] /
                                          num_replicas)
    logs['cross_entropy/teacher_on_u'] = cross_entropy['u']
    logs['cross_entropy/teacher_on_l'] = cross_entropy['l']
    logs['lr/student'] = tf.identity(lr['s']) / num_replicas
    logs['lr/teacher'] = tf.identity(lr['t']) / num_replicas
    logs['mpl/dot_product'] = dot_product / num_replicas
    logs['mpl/moving_dot_product'] = moving_dot_product / num_replicas
    logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
    logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
    logs['uda/weight'] = uda_weight / num_replicas

    # Each logged value gets a leading dimension of size 1, matching the
    # `[1]` shape recorded in `self.step_info`.
    tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
    self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
    def outfeed(tensors):
      # The outfeed enqueue is placed on the last core of the replica.
      with tf.device(tf.tpu.core(params.num_cores_per_replica-1)):
        return tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors)

    outfeed_enqueue_op = tf.cond(
        common_utils.should_log(params), lambda: outfeed(tensors), tf.no_op)

    return outfeed_enqueue_op
def step_fn(self, params, model):
  """Builds one UDA training step and returns its outfeed op.

  Separate implementation. The step dequeues labeled images, labels and two
  sets of unlabeled images ("ori" and "aug") from infeed, computes the
  cross entropies through `UDA.build_uda_cross_entropy`, and applies one
  optimizer update followed by an EMA update.

  Args:
    params: Hyper-parameter object read by attribute (for example
      `train_batch_size`, `num_replicas`, `uda_data`, `image_size`,
      `num_classes`, `weight_decay`, `uda_weight`, `uda_steps`,
      `grad_bound`).
    model: Model passed to `UDA.build_uda_cross_entropy`; its `name`
      attribute is used to build the EMA name scope.

  Returns:
    The result of a `tf.cond` that enqueues the logged tensors to outfeed when
    `common_utils.should_log(params)` is true and runs `tf.no_op` otherwise.

  Side effects:
    Sets `self.step_info` to a dict mapping each log name to
    `[tf.float32, [1]]`.
  """
  train_batch_size = params.train_batch_size
  num_replicas = params.num_replicas
  # Per-replica batch size for the labeled data.
  batch_size = train_batch_size // num_replicas

  # Infeed tuple layout: labeled images, labels, original unlabeled images,
  # augmented unlabeled images. The unlabeled batches are `uda_data` times
  # larger than the labeled batch.
  dtypes = [
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32]
  shapes = [
      [batch_size, params.image_size, params.image_size, 3],
      [batch_size, params.num_classes],
      [batch_size*params.uda_data, params.image_size, params.image_size, 3],
      [batch_size*params.uda_data, params.image_size, params.image_size, 3]]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    # Partitioned infeed: the three image tensors are partitioned along
    # dimension 2 across the cores of a replica; the labels are not.
    q = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=4,
        host_id=0,
        input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                              [1, 1],
                              [1, 1, params.num_cores_per_replica, 1],
                              [1, 1, params.num_cores_per_replica, 1],],
        device_assignment=params.device_assignment)
    q.set_tuple_types(dtypes)
    q.set_tuple_shapes(shapes)
    l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
    l_images = xla_sharding.split(l_images, 2,
                                  params.num_cores_per_replica)
    u_images_ori = xla_sharding.split(u_images_ori, 2,
                                      params.num_cores_per_replica)
    u_images_aug = xla_sharding.split(u_images_aug, 2,
                                      params.num_cores_per_replica)
  else:
    # Unpartitioned infeed, dequeued on core 0.
    with tf.device(tf.tpu.core(0)):
      (l_images, l_labels, u_images_ori,
       u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                     shapes=shapes)

  all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)
  global_step = tf.train.get_or_create_global_step()
  # From here on `num_replicas` is a float32 tensor (used to scale the logs).
  num_replicas = tf.cast(params.num_replicas, tf.float32)

  # NOTE(review): `masks` and `cross_entropy` are indexed below with keys
  # 'u' and 'l' -- confirm against `UDA.build_uda_cross_entropy`.
  with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
    _, _, masks, cross_entropy = UDA.build_uda_cross_entropy(
        params, model, all_images, l_labels)

  # The weight-decay rate is divided by the replica count; the gradients are
  # summed across replicas below.
  l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas, tf.float32)
  weight_dec = common_utils.get_l2_loss()
  # `uda_weight` ramps linearly from 0 to `params.uda_weight` over the first
  # `params.uda_steps` global steps and stays constant afterwards.
  uda_weight = params.uda_weight * tf.minimum(
      1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
  total_loss = (cross_entropy['u'] * uda_weight +
                cross_entropy['l'] +
                weight_dec * l2_reg_rate)
  variables = tf.trainable_variables()
  gradients = tf.gradients(total_loss, variables)
  gradients = [tf.tpu.cross_replica_sum(g) for g in gradients]
  gradients, grad_norm = tf.clip_by_global_norm(gradients, params.grad_bound)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

  learning_rate, optimizer = common_utils.get_optimizer(params)
  # Chain of control dependencies: update ops -> train op -> EMA -> logging.
  with tf.control_dependencies(update_ops):
    train_op = optimizer.apply_gradients(zip(gradients, variables),
                                         global_step=global_step)

  with tf.control_dependencies([train_op]):
    ema_train_op = common_utils.setup_ema(
        params, f'{MODEL_SCOPE}/{model.name}')

  with tf.control_dependencies([ema_train_op]):
    logs = collections.OrderedDict()
    logs['global_step'] = tf.cast(global_step, tf.float32)

    # NOTE(review): several entries are divided by `num_replicas`, presumably
    # because the consumer sums outfeed values across replicas -- confirm.
    logs['loss/total'] = total_loss
    logs['loss/cross_entropy'] = cross_entropy['l']
    logs['loss/lr'] = tf.identity(learning_rate) / num_replicas
    logs['loss/grad_norm'] = tf.identity(grad_norm) / num_replicas
    logs['loss/weight_dec'] = weight_dec / num_replicas

    logs['uda/cross_entropy'] = cross_entropy['u']
    logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
    logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
    logs['uda/weight'] = uda_weight / num_replicas

    # Each logged value gets a leading dimension of size 1, matching the
    # `[1]` shape recorded in `self.step_info`.
    tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
    self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
    outfeed_enqueue_op = tf.cond(
        common_utils.should_log(params),
        lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
  return outfeed_enqueue_op
def step_fn(self, params, model):
  """A single step for supervised learning.

  Dequeues a batch of images and labels from infeed, computes the cross
  entropy plus a weight-decay term, and applies one optimizer update
  (skipped when the gradient norm is not finite) followed by an EMA update.

  Args:
    params: Hyper-parameter object read by attribute (for example
      `train_batch_size`, `num_replicas`, `image_size`, `num_classes`,
      `dataset_name`, `label_smoothing`, `weight_decay`, `grad_bound`).
    model: Callable invoked as `model(images, training=True)`; its `name`
      attribute is used to build the EMA name scope.

  Returns:
    The result of a `tf.cond` that enqueues the logged tensors to outfeed when
    `common_utils.should_log(params)` is true and runs `tf.no_op` otherwise.

  Side effects:
    Sets `self.step_info` to a dict mapping each log name to
    `[tf.float32, [1]]`.
  """
  # Per-replica batch size.
  batch_size = params.train_batch_size // params.num_replicas
  # Infeed tuple layout: images, labels.
  dtypes = [tf.bfloat16 if params.use_bfloat16 else tf.float32, tf.float32]
  shapes = [[batch_size, params.image_size, params.image_size, 3],
            [batch_size, params.num_classes]]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    # Partitioned infeed: only the images are partitioned (along dimension 2).
    q = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=2,
        host_id=0,
        input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                              [1, 1]],
        device_assignment=params.device_assignment)
    q.set_tuple_types(dtypes)
    q.set_tuple_shapes(shapes)
    images, labels = q.generate_dequeue_op()
    images = xla_sharding.split(images, 2, params.num_cores_per_replica)
  else:
    # Unpartitioned infeed, dequeued on core 0.
    with tf.device(tf.tpu.core(0)):
      images, labels = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                     shapes=shapes)

  # Integer labels are converted to one-hot float32 vectors.
  if labels.dtype == tf.int32:
    labels = tf.one_hot(labels, depth=params.num_classes, dtype=tf.float32)
  global_step = tf.train.get_or_create_global_step()

  # Float32 tensor versions, used to normalize the loss and scale the logs.
  train_batch_size = tf.cast(params.train_batch_size, tf.float32)
  num_replicas = tf.cast(params.num_replicas, tf.float32)

  with tf.variable_scope(MODEL_SCOPE):
    logits = model(images, training=True)

  if 'noisy_student' in params.dataset_name.lower():
    # Cross entropy computed by hand against `labels` as given; no label
    # smoothing is applied on this path.
    cross_entropy = labels * tf.nn.log_softmax(logits, axis=-1)
    cross_entropy = tf.reduce_sum(-cross_entropy) / train_batch_size
  else:
    cross_entropy = tf.losses.softmax_cross_entropy(
        onehot_labels=labels, logits=logits,
        label_smoothing=params.label_smoothing,
        reduction=tf.losses.Reduction.SUM) / train_batch_size

  # The weight-decay rate is divided by the replica count; the gradients are
  # summed across replicas below.
  l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas, tf.float32)
  weight_dec = common_utils.get_l2_loss()
  total_loss = cross_entropy + weight_dec * l2_reg_rate

  variables = tf.trainable_variables()
  gradients = tf.gradients(total_loss, variables)
  gradients = [tf.tpu.cross_replica_sum(g) for g in gradients]
  gradients, grad_norm = tf.clip_by_global_norm(gradients, params.grad_bound)

  learning_rate, optimizer = common_utils.get_optimizer(params)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  # Skip the parameter update entirely when the gradient norm is inf or NaN.
  train_op = tf.cond(
      tf.math.is_finite(grad_norm),
      lambda: optimizer.apply_gradients(zip(gradients, variables),
                                        global_step=global_step),
      tf.no_op)
  # Chain of control dependencies: update ops + train op -> EMA -> logging.
  with tf.control_dependencies(update_ops + [train_op]):
    ema_train_op = common_utils.setup_ema(params,
                                          f'{MODEL_SCOPE}/{model.name}')

  with tf.control_dependencies([ema_train_op]):
    logs = collections.OrderedDict()
    logs['global_step'] = tf.cast(global_step, tf.float32)
    # NOTE(review): several entries are divided by `num_replicas`, presumably
    # because the consumer sums outfeed values across replicas -- confirm.
    logs['loss/total'] = total_loss
    logs['loss/weight_decay'] = weight_dec / num_replicas
    logs['loss/cross_entropy'] = cross_entropy
    logs['loss/lr'] = tf.identity(learning_rate) / num_replicas
    logs['loss/grad_norm'] = grad_norm / num_replicas

    # Each logged value gets a leading dimension of size 1, matching the
    # `[1]` shape recorded in `self.step_info`.
    tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
    self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
    outfeed_enqueue_op = tf.cond(
        common_utils.should_log(params),
        lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
  return outfeed_enqueue_op