def enqueue_fn(host_device=host, input_index=i, device_ordinals=tpus):
  """Docs."""
  worker_infeed_ops = []
  with tf.device(host_device):
    dataset = build_eval_dataset(
        params,
        batch_size=params.eval_batch_size // num_inputs,
        num_workers=num_inputs,
        worker_index=input_index)
    inputs = tf.data.make_one_shot_iterator(dataset).get_next()
    if params.use_xla_sharding and params.num_cores_per_replica > 1:
      inputs, partition_dims = pad_inputs_for_xla_sharding(params, inputs)
      num_splits = len(device_ordinals)
      if len(device_ordinals) > 1:
        inputs = [tf.split(v, num_splits, 0) for v in inputs]
      else:
        inputs = [[v] for v in inputs]
      q = tpu_feed._PartitionedInfeedQueue(
          number_of_tuple_elements=len(inputs),
          host_id=int(host_device.split('/task:')[-1].split('/')[0]),
          input_partition_dims=partition_dims,
          device_assignment=dev_assign)
      inputs = [[v[i] for v in inputs] for i in range(num_splits)]
      worker_infeed_ops.extend(q.generate_enqueue_ops(inputs))
    else:
      num_splits = len(device_ordinals)
      if len(device_ordinals) > 1:
        inputs = [tf.split(v, num_splits, 0) for v in inputs]
      else:
        inputs = [[v] for v in inputs]
      input_shapes = [v[0].shape for v in inputs]
      for j, device_ordinal in enumerate(device_ordinals):
        worker_infeed_ops.append(
            tf.raw_ops.InfeedEnqueueTuple(
                inputs=[v[j] for v in inputs],
                shapes=input_shapes,
                device_ordinal=device_ordinal))
  return worker_infeed_ops
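# A minimal plain-Python sketch (hypothetical values, not part of the original
# code) of the regrouping comprehension above: `inputs` holds one list of
# per-core shards per tuple element, and the comprehension transposes it into
# one tuple of shards per core ordinal, which is the layout that
# generate_enqueue_ops() expects.
inputs = [['img_shard0', 'img_shard1'], ['lbl_shard0', 'lbl_shard1']]
num_splits = 2
per_core = [[v[i] for v in inputs] for i in range(num_splits)]
assert per_core == [['img_shard0', 'lbl_shard0'], ['img_shard1', 'lbl_shard1']]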
def test_infeed_uneven_partition(self):
  """Tests uneven infeed tensors partition."""
  ds = device_assignment(
      self._topology_2x2x2, num_replicas=1, computation_shape=[2, 2, 1, 2])
  input_partition_dims = [[4, 2]]
  # pylint: disable=protected-access
  partitioned_infeed = tpu_feed._PartitionedInfeedQueue(
      number_of_tuple_elements=1,
      host_id=0,
      input_partition_dims=input_partition_dims,
      device_assignment=ds)
  x = array_ops.zeros((14, 5))
  tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(
      x, dims=input_partition_dims[0])
  self.assertEqual(8, len(tensors))
  self.assertEqual((2, 2), tensors[-1].shape)
def test_infeed_tailing_zero_partition(self):
  """Tests infeed tensors partition which causes zero-size tensors."""
  ds = device_assignment(
      self._topology_2x2x2, num_replicas=1, computation_shape=[1, 2, 1, 2])
  input_partition_dims = [[4, 1]]
  # pylint: disable=protected-access
  partitioned_infeed = tpu_feed._PartitionedInfeedQueue(
      number_of_tuple_elements=1,
      host_id=0,
      input_partition_dims=input_partition_dims,
      device_assignment=ds)
  x = array_ops.zeros((5, 5))
  tensors = partitioned_infeed._check_dims_and_partition_or_replicate_on_host(
      x, dims=input_partition_dims[0])
  self.assertEqual(4, len(tensors))
  self.assertEqual((1, 5), tensors[2].shape)
  self.assertEqual((0, 5), tensors[3].shape)
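# A plain-Python sketch (an assumption about the behavior the two tests above
# exercise, not the TF implementation) of uneven partitioning: each shard
# takes ceil(n / p) elements along a partitioned axis, so trailing shards
# shrink and may be empty. split_sizes() is a hypothetical helper.
def split_sizes(n, p):
  shard = -(-n // p)  # ceil division
  sizes, remaining = [], n
  for _ in range(p):
    sizes.append(min(shard, max(remaining, 0)))
    remaining -= shard
  return sizes

assert split_sizes(14, 4) == [4, 4, 4, 2]  # row tiles of the (14, 5) case
assert split_sizes(5, 4) == [2, 2, 1, 0]   # tailing zero-size tile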
def _config_infeed(self,
                   num_partitions,
                   device_assignment,
                   batch_size,
                   key_size=2,
                   return_tgt_mask=False,
                   use_partitioned_infeed_queue=False):
  """Config the infeed ops and args."""
  zero_batch = get_zero_batch(
      batch_size=batch_size,
      max_len=self._prefix_max_len,
      key_size=key_size,
      return_tgt_mask=return_tgt_mask)

  host_device = device_assignment.host_device(replica=0, job=self._tpu)
  host_id = int(host_device.split('/task:')[1].split('/device:')[0])
  input_partition_dims = [
      [num_partitions] + [1] * (len(x.shape) - 1) for x in zero_batch
  ]
  if use_partitioned_infeed_queue:
    infeed = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
        number_of_tuple_elements=len(zero_batch),
        host_id=host_id,
        input_partition_dims=input_partition_dims,
        device_assignment=device_assignment)
  else:
    infeed = tpu_feed.InfeedQueue(number_of_tuple_elements=len(zero_batch))

  self.infeed_args = []
  for x in zero_batch:
    p = tf.placeholder(tf.as_dtype(x.dtype), shape=x.shape)
    self.infeed_args += [p]
  if use_partitioned_infeed_queue:
    self.infeed_op = infeed.generate_enqueue_ops([self.infeed_args])
  else:
    self.infeed_op = infeed.split_inputs_and_generate_enqueue_ops(
        self.infeed_args, device_assignment=device_assignment)
  return infeed
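# A minimal sketch (hypothetical shapes) of how `input_partition_dims` above
# is derived: only the leading (batch) dimension is partitioned into
# `num_partitions` pieces; every other dimension stays whole.
import numpy as np

zero_batch = [np.zeros((8, 16)), np.zeros((8,)), np.zeros((8, 4, 4))]
num_partitions = 2
dims = [[num_partitions] + [1] * (len(x.shape) - 1) for x in zero_batch]
assert dims == [[2, 1], [2], [2, 1, 1]]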
def eval_step_fn(params, model):
  """Build `step_fn` for eval."""
  dtypes = [
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.float32,
      tf.float32,
  ]
  batch_size = params.eval_batch_size // params.num_replicas
  image_size = (params.eval_image_size if 'eval_image_size' in params
                else params.image_size)
  shapes = [[batch_size, image_size, image_size, 3],
            [batch_size, params.num_classes],
            [batch_size]]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    q = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=3,
        host_id=0,
        input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                              [1, 1],
                              [1]],
        device_assignment=params.device_assignment)
    q.set_tuple_types(dtypes)
    q.set_tuple_shapes(shapes)
    images, labels, mask = q.generate_dequeue_op()
    images = xla_sharding.split(images, 2, params.num_cores_per_replica)
  else:
    with tf.device(tf.tpu.core(0)):
      images, labels, mask = tf.raw_ops.InfeedDequeueTuple(
          dtypes=dtypes, shapes=shapes)

  if len(labels.shape) > 1:  # `labels` is one-hot; convert it to `tf.int32`.
    labels = tf.argmax(labels, axis=-1, output_type=tf.int32)
    labels = tf.expand_dims(labels, axis=-1)
  _ = tf.train.get_or_create_global_step()

  with tf.variable_scope(MODEL_SCOPE):
    logits = model(images, training=False)
    logits = tf.cast(logits, tf.float32)
  return logits, labels, mask
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.format(
          cluster.num_splits_per_client, cluster.num_devices_per_split,
          num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'.format(
                         cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (tpu_embedding_collection[0]
                     if tpu_embedding_collection else None)

    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                          if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.

        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)

        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()
          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[[p.num_partitions, 1] for _ in dtypes],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)
        queues.append(q)
        tf.logging.info('q=%r', q)

        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)
    self._tpu_infeed_op = tpu_infeed_op

  with tf.device(tf.tpu.core(0)):
    tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
def CreateTpuEnqueueOps(self):
  """Create the host-side enqueue ops.

  This should be called in an outer non-TPU context.
  """
  assert not self._tpu_queues, ('CreateTpuEnqueueOps should only be called '
                                'once.')
  self._tpu_queues = []
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTpuEnqueueOps num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.format(
          cluster.num_splits_per_client, cluster.num_devices_per_split,
          num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'.format(
                         cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  shards = (cluster.total_worker_devices //
            num_infeed_hosts) // cluster.num_devices_per_split
  tf.logging.info('shards {}'.format(shards))

  input_ops_list = []
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                        if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  tf.logging.info('num_infeed_hosts: %d', num_infeed_hosts)

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      self._batch = self.GetPreprocessedInputBatch()
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      for k, x in self._batch.FlattenItems():
        assert x.shape.is_fully_defined(), (
            'Shape must be fully defined: %s: %s' % (k, x))
        # TODO(cwhipkey): if it's a string (or other type not supported on
        # TPU), drop it from feeding and on the other end add in an op that
        # fails if used.

      shapes = self._batch.Transform(lambda x: x.shape).Flatten()
      dtypes = self._batch.Transform(lambda x: x.dtype).Flatten()
      tf.logging.info('host_device: %s infeed shapes: %r', host_device, shapes)
      tf.logging.info('host_device: %s infeed dtypes: %r', host_device, dtypes)

      if p.use_partitioned_infeed_queue:
        device_assignment = py_utils.GetTpuDeviceAssignment()
        host_device = device_assignment.host_device(
            replica=0, job=tf.flags.FLAGS.tf_master)
        host_id = int(host_device.split('/task:')[1].split('/device:')[0])
        tf.logging.info('host_id: {} host_device: {}'.format(
            host_id, host_device))
        q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
            number_of_tuple_elements=len(dtypes),
            device_assignment=device_assignment,
            host_id=host_id,
            input_partition_dims=[[p.num_partitions] + [1] * (len(s) - 1)
                                  for s in shapes],
            tuple_types=dtypes,
            tuple_shapes=shapes)
      else:
        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        assert shards is not None
        q.set_number_of_shards(shards)
      self._tpu_queues.append(q)

      if p.use_partitioned_infeed_queue:
        input_ops = q.generate_enqueue_ops([self._batch.Flatten()])
      elif p.use_per_host_infeed:
        # TODO(ylc/zhifengc): Add this to a policy module and test it.
        def TPUOrdinalFunction(shard_index_in_host):
          device_assignment = py_utils.GetTpuDeviceAssignment()
          if device_assignment:
            # We put both enqueue/dequeue ops at core 0 in each replica.
            replica = device_assignment.lookup_replicas(
                task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
            return device_assignment.tpu_ordinal(replica=replica)
          else:
            return shard_index_in_host

        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
            tpu_ordinal_function=TPUOrdinalFunction)
      else:
        input_ops = q.split_inputs_and_generate_enqueue_ops(
            self._batch.Flatten(),
            device_assignment=py_utils.GetTpuDeviceAssignment())
      input_ops_list += input_ops

  tf.logging.info('input_ops_list %s', input_ops_list)
  grouped_infeed_op = tf.group(*input_ops_list)
  self._tpu_infeed_op = []
  for _ in range(p.tpu_infeed_parallelism):
    self._tpu_infeed_op.append(grouped_infeed_op)
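# A quick plain-Python check (hypothetical device string) of the host-id
# parsing used above: the task index is the text between '/task:' and
# '/device:' in the host device name.
host_device = '/job:worker/replica:0/task:3/device:CPU:0'
host_id = int(host_device.split('/task:')[1].split('/device:')[0])
assert host_id == 3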
def step_fn(self, params, model):
  """Separate implementation."""
  train_batch_size = params.train_batch_size
  num_replicas = params.num_replicas
  uda_data = params.uda_data
  batch_size = train_batch_size // num_replicas

  dtypes = [
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
  ]
  shapes = [
      [batch_size, params.image_size, params.image_size, 3],
      [batch_size, params.num_classes],
      [batch_size * params.uda_data, params.image_size, params.image_size, 3],
      [batch_size * params.uda_data, params.image_size, params.image_size, 3],
  ]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    q = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=4,
        host_id=0,
        input_partition_dims=[
            [1, 1, params.num_cores_per_replica, 1],
            [1, 1],
            [1, 1, params.num_cores_per_replica, 1],
            [1, 1, params.num_cores_per_replica, 1],
        ],
        device_assignment=params.device_assignment)
    q.set_tuple_types(dtypes)
    q.set_tuple_shapes(shapes)
    l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
    l_images = xla_sharding.split(l_images, 2, params.num_cores_per_replica)
    u_images_ori = xla_sharding.split(u_images_ori, 2,
                                      params.num_cores_per_replica)
    u_images_aug = xla_sharding.split(u_images_aug, 2,
                                      params.num_cores_per_replica)
  else:
    with tf.device(tf.tpu.core(0)):
      (l_images, l_labels, u_images_ori,
       u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                     shapes=shapes)

  global_step = tf.train.get_or_create_global_step()
  num_replicas = tf.cast(params.num_replicas, tf.float32)

  all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)

  # all calls to teacher
  with tf.variable_scope('teacher', reuse=tf.AUTO_REUSE):
    logits, labels, masks, cross_entropy = UDA.build_uda_cross_entropy(
        params, model, all_images, l_labels)

  # 1st call to student
  with tf.variable_scope(MODEL_SCOPE):
    u_aug_and_l_images = tf.concat([u_images_aug, l_images], axis=0)
    logits['s_on_u_aug_and_l'] = model(u_aug_and_l_images, training=True)
    logits['s_on_u'], logits['s_on_l_old'] = tf.split(
        logits['s_on_u_aug_and_l'],
        [u_images_aug.shape[0].value, l_images.shape[0].value], axis=0)

    # for backprop
    cross_entropy['s_on_u'] = tf.losses.softmax_cross_entropy(
        onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], -1)),
        logits=logits['s_on_u'],
        label_smoothing=params.label_smoothing,
        reduction=tf.losses.Reduction.NONE)
    cross_entropy['s_on_u'] = tf.reduce_sum(cross_entropy['s_on_u']) / float(
        train_batch_size * uda_data)

    # for Taylor
    cross_entropy['s_on_l_old'] = tf.losses.softmax_cross_entropy(
        onehot_labels=labels['l'],
        logits=logits['s_on_l_old'],
        reduction=tf.losses.Reduction.SUM)
    cross_entropy['s_on_l_old'] = tf.tpu.cross_replica_sum(
        cross_entropy['s_on_l_old']) / float(train_batch_size)
    shadow = tf.get_variable(
        name='cross_entropy_old', shape=[], trainable=False, dtype=tf.float32)
    shadow_update = tf.assign(shadow, cross_entropy['s_on_l_old'])

  w_s = {}
  g_s = {}
  g_n = {}
  lr = {}
  optim = {}
  w_s['s'] = [
      w for w in tf.trainable_variables()
      if w.name.lower().startswith(MODEL_SCOPE)
  ]
  g_s['s_on_u'] = tf.gradients(cross_entropy['s_on_u'], w_s['s'])
  # g_s['s_on_u'] = [tf.tpu.cross_replica_sum(g) for g in g_s['s_on_u']]

  lr['s'] = common_utils.get_learning_rate(
      params,
      initial_lr=params.mpl_student_lr,
      num_warmup_steps=params.mpl_student_lr_warmup_steps,
      num_wait_steps=params.mpl_student_lr_wait_steps)
  lr['s'], optim['s'] = common_utils.get_optimizer(
      params, learning_rate=lr['s'])
  optim['s']._create_slots(w_s['s'])  # pylint: disable=protected-access
  update_ops = [
      op for op in tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      if op.name.startswith(f'train/{MODEL_SCOPE}/')
  ]

  with tf.control_dependencies(update_ops + [shadow_update]):
    g_s['s_on_u'] = common_utils.add_weight_decay(
        params, w_s['s'], g_s['s_on_u'])
    g_s['s_on_u'], g_n['s_on_u'] = tf.clip_by_global_norm(
        g_s['s_on_u'], params.grad_bound)
    train_op = optim['s'].apply_gradients(zip(g_s['s_on_u'], w_s['s']))

    with tf.control_dependencies([train_op]):
      ema_train_op = common_utils.setup_ema(
          params, name_scope=f'{MODEL_SCOPE}/{model.name}')

  # 2nd call to student
  with tf.control_dependencies([ema_train_op]):
    with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
      logits['s_on_l_new'] = model(l_images, training=True)

  cross_entropy['s_on_l_new'] = tf.losses.softmax_cross_entropy(
      onehot_labels=labels['l'],
      logits=logits['s_on_l_new'],
      reduction=tf.losses.Reduction.SUM)
  cross_entropy['s_on_l_new'] = tf.tpu.cross_replica_sum(
      cross_entropy['s_on_l_new']) / float(train_batch_size)

  dot_product = cross_entropy['s_on_l_new'] - shadow
  # dot_product = tf.clip_by_value(
  #     dot_product,
  #     clip_value_min=-params.mpl_dot_product_bound,
  #     clip_value_max=params.mpl_dot_product_bound)
  moving_dot_product = tf.get_variable(
      'moving_dot_product', shape=[], trainable=False, dtype=tf.float32)
  moving_dot_product_update = tf.assign_sub(
      moving_dot_product, 0.01 * (moving_dot_product - dot_product))
  with tf.control_dependencies([moving_dot_product_update]):
    dot_product = dot_product - moving_dot_product
    dot_product = tf.stop_gradient(dot_product)

  cross_entropy['mpl'] = tf.losses.softmax_cross_entropy(
      onehot_labels=tf.stop_gradient(tf.nn.softmax(logits['u_aug'], axis=-1)),
      logits=logits['u_aug'],
      reduction=tf.losses.Reduction.NONE)
  cross_entropy['mpl'] = tf.reduce_sum(cross_entropy['mpl']) / float(
      train_batch_size * uda_data)

  # teacher train op
  uda_weight = params.uda_weight * tf.minimum(
      1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
  teacher_loss = (cross_entropy['u'] * uda_weight +
                  cross_entropy['l'] +
                  cross_entropy['mpl'] * dot_product)
  w_s['t'] = [w for w in tf.trainable_variables() if 'teacher' in w.name]
  g_s['t'] = tf.gradients(teacher_loss, w_s['t'])
  g_s['t'] = common_utils.add_weight_decay(params, w_s['t'], g_s['t'])
  g_s['t'], g_n['t'] = tf.clip_by_global_norm(g_s['t'], params.grad_bound)

  lr['t'] = common_utils.get_learning_rate(
      params,
      initial_lr=params.mpl_teacher_lr,
      num_warmup_steps=params.mpl_teacher_lr_warmup_steps)
  lr['t'], optim['t'] = common_utils.get_optimizer(
      params, learning_rate=lr['t'])

  teacher_train_op = optim['t'].apply_gradients(
      zip(g_s['t'], w_s['t']), global_step=global_step)

  with tf.control_dependencies([teacher_train_op]):
    logs = collections.OrderedDict()
    logs['global_step'] = tf.cast(global_step, tf.float32)

    logs['cross_entropy/student_on_u'] = cross_entropy['s_on_u']
    logs['cross_entropy/student_on_l'] = (cross_entropy['s_on_l_new'] /
                                          num_replicas)
    logs['cross_entropy/teacher_on_u'] = cross_entropy['u']
    logs['cross_entropy/teacher_on_l'] = cross_entropy['l']
    logs['lr/student'] = tf.identity(lr['s']) / num_replicas
    logs['lr/teacher'] = tf.identity(lr['t']) / num_replicas
    logs['mpl/dot_product'] = dot_product / num_replicas
    logs['mpl/moving_dot_product'] = moving_dot_product / num_replicas
    logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
    logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
    logs['uda/weight'] = uda_weight / num_replicas

    tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
    self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}

    def outfeed(tensors):
      with tf.device(tf.tpu.core(params.num_cores_per_replica - 1)):
        return tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors)

    outfeed_enqueue_op = tf.cond(
        common_utils.should_log(params), lambda: outfeed(tensors), tf.no_op)
  return outfeed_enqueue_op
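# A plain-Python sketch of the moving-average update in the step above:
# assign_sub(m, 0.01 * (m - d)) is the exponential moving average
# m <- 0.99 * m + 0.01 * d, so `moving_dot_product` tracks a smoothed
# estimate of the dot product.
m, d = 1.0, 3.0
m -= 0.01 * (m - d)
assert abs(m - (0.99 * 1.0 + 0.01 * 3.0)) < 1e-12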
def step_fn(self, params, model):
  """Separate implementation."""
  train_batch_size = params.train_batch_size
  num_replicas = params.num_replicas
  batch_size = train_batch_size // num_replicas

  dtypes = [
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
      tf.bfloat16 if params.use_bfloat16 else tf.float32,
  ]
  shapes = [
      [batch_size, params.image_size, params.image_size, 3],
      [batch_size, params.num_classes],
      [batch_size * params.uda_data, params.image_size, params.image_size, 3],
      [batch_size * params.uda_data, params.image_size, params.image_size, 3],
  ]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    q = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=4,
        host_id=0,
        input_partition_dims=[
            [1, 1, params.num_cores_per_replica, 1],
            [1, 1],
            [1, 1, params.num_cores_per_replica, 1],
            [1, 1, params.num_cores_per_replica, 1],
        ],
        device_assignment=params.device_assignment)
    q.set_tuple_types(dtypes)
    q.set_tuple_shapes(shapes)
    l_images, l_labels, u_images_ori, u_images_aug = q.generate_dequeue_op()
    l_images = xla_sharding.split(l_images, 2, params.num_cores_per_replica)
    u_images_ori = xla_sharding.split(u_images_ori, 2,
                                      params.num_cores_per_replica)
    u_images_aug = xla_sharding.split(u_images_aug, 2,
                                      params.num_cores_per_replica)
  else:
    with tf.device(tf.tpu.core(0)):
      (l_images, l_labels, u_images_ori,
       u_images_aug) = tf.raw_ops.InfeedDequeueTuple(dtypes=dtypes,
                                                     shapes=shapes)
  all_images = tf.concat([l_images, u_images_ori, u_images_aug], axis=0)

  global_step = tf.train.get_or_create_global_step()
  num_replicas = tf.cast(params.num_replicas, tf.float32)

  with tf.variable_scope(MODEL_SCOPE, reuse=tf.AUTO_REUSE):
    _, _, masks, cross_entropy = UDA.build_uda_cross_entropy(
        params, model, all_images, l_labels)

  l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas, tf.float32)
  weight_dec = common_utils.get_l2_loss()
  uda_weight = params.uda_weight * tf.minimum(
      1., tf.cast(global_step, tf.float32) / float(params.uda_steps))
  total_loss = (cross_entropy['u'] * uda_weight +
                cross_entropy['l'] +
                weight_dec * l2_reg_rate)

  variables = tf.trainable_variables()
  gradients = tf.gradients(total_loss, variables)
  gradients = [tf.tpu.cross_replica_sum(g) for g in gradients]
  gradients, grad_norm = tf.clip_by_global_norm(gradients, params.grad_bound)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

  learning_rate, optimizer = common_utils.get_optimizer(params)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.apply_gradients(
        zip(gradients, variables), global_step=global_step)

  with tf.control_dependencies([train_op]):
    ema_train_op = common_utils.setup_ema(
        params, f'{MODEL_SCOPE}/{model.name}')

  with tf.control_dependencies([ema_train_op]):
    logs = collections.OrderedDict()
    logs['global_step'] = tf.cast(global_step, tf.float32)
    logs['loss/total'] = total_loss
    logs['loss/cross_entropy'] = cross_entropy['l']
    logs['loss/lr'] = tf.identity(learning_rate) / num_replicas
    logs['loss/grad_norm'] = tf.identity(grad_norm) / num_replicas
    logs['loss/weight_dec'] = weight_dec / num_replicas

    logs['uda/cross_entropy'] = cross_entropy['u']
    logs['uda/u_ratio'] = tf.reduce_mean(masks['u']) / num_replicas
    logs['uda/l_ratio'] = tf.reduce_mean(masks['l']) / num_replicas
    logs['uda/weight'] = uda_weight / num_replicas

    tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
    self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
    outfeed_enqueue_op = tf.cond(
        common_utils.should_log(params),
        lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
  return outfeed_enqueue_op
def step_fn(self, params, model):
  """A single step for supervised learning."""
  batch_size = params.train_batch_size // params.num_replicas
  dtypes = [tf.bfloat16 if params.use_bfloat16 else tf.float32, tf.float32]
  shapes = [[batch_size, params.image_size, params.image_size, 3],
            [batch_size, params.num_classes]]

  if params.use_xla_sharding and params.num_cores_per_replica > 1:
    q = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=2,
        host_id=0,
        input_partition_dims=[[1, 1, params.num_cores_per_replica, 1],
                              [1, 1]],
        device_assignment=params.device_assignment)
    q.set_tuple_types(dtypes)
    q.set_tuple_shapes(shapes)
    images, labels = q.generate_dequeue_op()
    images = xla_sharding.split(images, 2, params.num_cores_per_replica)
  else:
    with tf.device(tf.tpu.core(0)):
      images, labels = tf.raw_ops.InfeedDequeueTuple(
          dtypes=dtypes, shapes=shapes)
  if labels.dtype == tf.int32:
    labels = tf.one_hot(labels, depth=params.num_classes, dtype=tf.float32)
  global_step = tf.train.get_or_create_global_step()

  train_batch_size = tf.cast(params.train_batch_size, tf.float32)
  num_replicas = tf.cast(params.num_replicas, tf.float32)

  with tf.variable_scope(MODEL_SCOPE):
    logits = model(images, training=True)

  if 'noisy_student' in params.dataset_name.lower():
    cross_entropy = labels * tf.nn.log_softmax(logits, axis=-1)
    cross_entropy = tf.reduce_sum(-cross_entropy) / train_batch_size
  else:
    cross_entropy = tf.losses.softmax_cross_entropy(
        onehot_labels=labels, logits=logits,
        label_smoothing=params.label_smoothing,
        reduction=tf.losses.Reduction.SUM) / train_batch_size

  l2_reg_rate = tf.cast(params.weight_decay / params.num_replicas, tf.float32)
  weight_dec = common_utils.get_l2_loss()
  total_loss = cross_entropy + weight_dec * l2_reg_rate

  variables = tf.trainable_variables()
  gradients = tf.gradients(total_loss, variables)
  gradients = [tf.tpu.cross_replica_sum(g) for g in gradients]
  gradients, grad_norm = tf.clip_by_global_norm(gradients, params.grad_bound)

  learning_rate, optimizer = common_utils.get_optimizer(params)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  # Skip the weight update when the gradient norm is not finite.
  train_op = tf.cond(
      tf.math.is_finite(grad_norm),
      lambda: optimizer.apply_gradients(
          zip(gradients, variables), global_step=global_step),
      tf.no_op)
  with tf.control_dependencies(update_ops + [train_op]):
    ema_train_op = common_utils.setup_ema(
        params, f'{MODEL_SCOPE}/{model.name}')

  with tf.control_dependencies([ema_train_op]):
    logs = collections.OrderedDict()
    logs['global_step'] = tf.cast(global_step, tf.float32)
    logs['loss/total'] = total_loss
    logs['loss/weight_decay'] = weight_dec / num_replicas
    logs['loss/cross_entropy'] = cross_entropy
    logs['loss/lr'] = tf.identity(learning_rate) / num_replicas
    logs['loss/grad_norm'] = grad_norm / num_replicas

    tensors = [tf.expand_dims(t, axis=0) for t in logs.values()]
    self.step_info = {k: [tf.float32, [1]] for k in logs.keys()}
    outfeed_enqueue_op = tf.cond(
        common_utils.should_log(params),
        lambda: tf.raw_ops.OutfeedEnqueueTuple(inputs=tensors), tf.no_op)
  return outfeed_enqueue_op
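# A small NumPy sketch (hypothetical values) of the 'noisy_student' branch
# above: sum(-labels * log_softmax(logits)) is the cross-entropy against a
# soft teacher distribution; with a one-hot label it reduces to the usual
# negative log-likelihood of the true class.
import numpy as np

logits = np.array([2.0, 0.5, -1.0])
log_softmax = logits - np.log(np.exp(logits).sum())
one_hot = np.array([1.0, 0.0, 0.0])
nll = -(one_hot * log_softmax).sum()
assert np.isclose(nll, -log_softmax[0])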
def enqueue_ops_fn(idx):
  """Generate the infeed enqueue ops graph."""
  per_host_sharded_inputs = []
  control_deps = []
  for _ in range(FLAGS.replicas_per_host):
    with tf.control_dependencies(control_deps):
      self.feature_structure[is_training] = iterator.get_next()
    self.maybe_capture_embedding_inputs(
        self.feature_structure[is_training], is_training)
    flattened_inputs = tf.nest.flatten(self.feature_structure[is_training])
    control_deps.extend(flattened_inputs)
    if input_partition_dims:
      padded_inputs = []
      for inp in flattened_inputs:
        if inp.shape.ndims < len(input_partition_dims):
          padded_inputs.append(inp)
          continue
        paddings = []
        for i, j in enumerate(input_partition_dims):
          r = inp.shape.as_list()[i] % j
          if r > 0:
            paddings.append([0, j - r])
          else:
            paddings.append([0, 0])
        for i in range(inp.shape.ndims - len(input_partition_dims)):
          paddings.append([0, 0])
        padded_inputs.append(tf.pad(inp, paddings))
      per_host_sharded_inputs.append(padded_inputs)
    else:
      per_host_sharded_inputs.append(flattened_inputs)

  if input_partition_dims:
    flattened_input_dims = []
    for i in per_host_sharded_inputs[0]:
      if i.shape.ndims == len(input_partition_dims):
        flattened_input_dims.append(input_partition_dims)
      elif i.shape.ndims > len(input_partition_dims):
        flattened_input_dims.append(
            input_partition_dims + [1] *
            (i.shape.ndims - len(input_partition_dims)))
      else:
        flattened_input_dims.append([1] * i.shape.ndims)
    # pylint: disable=protected-access
    self.infeed_op[is_training] = tpu_feed._PartitionedInfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]),
        host_id=host_id,
        input_partition_dims=flattened_input_dims,
        device_assignment=self.device_assignment)
    with tf.control_dependencies(
        self.infeed_op[is_training].generate_enqueue_ops(
            per_host_sharded_inputs)):
      return idx + 1
  else:
    self.infeed_op[is_training] = tpu_feed.InfeedQueue(
        number_of_tuple_elements=len(per_host_sharded_inputs[0]))
    per_host_enqueue_ops = (
        self.infeed_op[is_training].generate_enqueue_ops(
            per_host_sharded_inputs, tpu_ordinal_function=_tpu_ordinal_fn))

    self.maybe_add_embedding_enqueue_ops_int(is_training,
                                             per_host_enqueue_ops)
    with tf.control_dependencies(per_host_enqueue_ops):
      return idx + 1
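# A minimal sketch (hypothetical shapes; pad_amounts() is not part of the
# original code) of the padding computed above: each partitioned dimension is
# padded up to the next multiple of its partition count so tf.split can
# divide it evenly, and trailing non-partitioned dimensions are left alone.
def pad_amounts(shape, partition_dims):
  paddings = []
  for size, parts in zip(shape, partition_dims):
    r = size % parts
    paddings.append([0, parts - r] if r else [0, 0])
  paddings += [[0, 0]] * (len(shape) - len(partition_dims))
  return paddings

assert pad_amounts((14, 5), [4, 2]) == [[0, 2], [0, 1]]
assert pad_amounts((16, 5, 3), [4, 1]) == [[0, 0], [0, 0], [0, 0]]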