def _CreateLayerVariables(self):
  p = self.params
  w_pc = py_utils.WeightParams(
      shape=[self._ids_per_shard, p.embedding_dim],
      init=p.params_init,
      dtype=p.dtype,
      collections=[self.__class__.__name__ + '_vars'])

  embedding_table_vars = []
  for i in range(p.num_tpu_hosts):
    device_name = self.GetDeviceName(i)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      var_name = self.GetVariableName(i)
      self.CreateVariable(var_name, w_pc)
      embedding_var = self.vars[var_name]
      embedding_table_vars.append(embedding_var)
      # Remove from _private_vars / _private_thetas to be added later as wm.
      del self._private_vars[var_name]
      del self._private_theta[var_name]

  if not py_utils.use_tpu():
    # We don't want to add this for TrainerTpu, otherwise the identity
    # reference leads to copying the embedding to the TPU for no reason.
    # However, this is needed for CPU (eval/decode/controller).
    self._private_vars['wm'] = embedding_table_vars
    self._private_theta['wm'] = [tf.identity(v) for v in embedding_table_vars]

  # Only trainer and controller need slot variables and load/retrieve ops.
  if not self.do_eval:
    self._load_op_list, self._retrieve_op_list = (
        self.optimizer.CreateSlotVariablesAndOps(embedding_table_vars, self))
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      w_ada = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name = tpu_embedding_table.GetVariableName(host_id) + '/Adagrad'
      tpu_embedding_table.CreateVariable(var_name, w_ada, trainable=False)
      accumulator_var = tpu_embedding_table.vars[var_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid copying them
        # to TPU (by the tf.cast in tpu_embedding_table.theta).
        # pylint: disable=protected-access
        del tpu_embedding_table._private_vars[var_name]
        del tpu_embedding_table._private_theta[var_name]
        # pylint: enable=protected-access

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_adagrad_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adagrad_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list
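A minimal usage sketch of the returned op lists (an assumption about the call site, not code from this file): the trainer is expected to run the load ops once after variable initialization and the retrieve ops before writing a checkpoint, so the CPU-side variables and the TPU-side embedding tables stay in sync. `sess`, `saver`, and `ckpt_path` below are hypothetical names used only for illustration.

load_ops, retrieve_ops = optimizer.CreateSlotVariablesAndOps(
    embedding_table_vars, tpu_embedding_table)
sess.run(tf.group(*load_ops))      # Push CPU variable values into TPU tables.
# ... run training steps ...
sess.run(tf.group(*retrieve_ops))  # Pull updated tables back into variables.
saver.save(sess, ckpt_path)        # Checkpoint now reflects the TPU state.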
def _CreateLayerVariables(self):
  p = self.params

  # Reuse the singleton table variables if they were created before.
  all_table_vars = self._tpu_embedding_collection.table_variables
  if self.table_name in all_table_vars:
    embedding_table_vars = all_table_vars[self.table_name]
  else:
    w_pc = py_utils.WeightParams(
        shape=[self._ids_per_shard, p.embedding_dim],
        init=p.params_init,
        dtype=p.dtype,
        collections=[self.__class__.__name__ + '_vars'])

    embedding_table_vars = []
    for i in range(p.num_tpu_hosts):
      device_name = self.GetDeviceName(i)
      with tf.device(device_name), py_utils.outside_all_rewrites():
        var_name = self.GetVariableName(i)
        self.CreateVariable(var_name, w_pc)
        embedding_var = self.vars[var_name]
        embedding_table_vars.append(embedding_var)
        # Remove from _private_vars / _private_thetas to be added later as wm.
        _RemovePrivateVar(self, var_name)

    self._tpu_embedding_collection.AddTableVariables(self.table_name,
                                                     embedding_table_vars)

  if not _ShouldUseTpu(p):
    # We don't want to add this for TrainerTpu, otherwise the identity
    # reference leads to copying the embedding to the TPU for no reason.
    # However, this is needed for CPU (eval/decode/controller).
    self._private_vars['wm'] = embedding_table_vars
    self._private_theta['wm'] = [tf.identity(v) for v in embedding_table_vars]

  # If slot variables and load/retrieve ops were created before, maybe by a
  # different program or task, don't create them again.
  # Note that there should be only one copy of slot variables and
  # load/retrieve ops in the graph and they're shared by different
  # tasks/programs.
  all_load_ops = self._tpu_embedding_collection.load_ops
  if self.table_name not in all_load_ops:
    assert self.table_name not in self._tpu_embedding_collection.retrieve_ops
    # Only trainer and controller (for checkpointing) need slot variables.
    # Only trainer needs load/retrieve ops.
    if not self.do_eval and not p.is_inference:
      load_ops, retrieve_ops = self.optimizer.CreateSlotVariablesAndOps(
          embedding_table_vars, self)
      self._tpu_embedding_collection.AddLoadRetrieveOps(
          self.table_name, load_ops, retrieve_ops)
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # SGD has no slot variables; keep the load/retrieve ops on the same
    # device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops
              .load_tpu_embedding_stochastic_gradient_descent_parameters(
                  parameters=table_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_stochastic_gradient_descent_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info(
      'CreateTPUFeeds num_splits_per_client={} '
      'num_devices_per_split={} num_tpu_hosts={} use_per_host_infeed={}'.format(
          cluster.num_splits_per_client, cluster.num_devices_per_split,
          num_tpu_hosts, p.use_per_host_infeed))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'.format(
                         cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    tf.logging.info('shards {}'.format(shards))

    input_ops_list = []
    queues = []
    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (
        tpu_embedding_collection[0] if tpu_embedding_collection else None)

    if num_tpu_hosts > 1 and tpu_embedding is not None:
      if not p.use_per_host_infeed:
        tf.logging.fatal(
            'TPU Embedding must be used with per_host_infeed with multiple '
            'TPU host topologies.')
    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if isinstance(batch, py_utils.NestedMap):
          # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
          # Note that when MultiTaskData is used, bucket_keys will be at the
          # second level of the dictionary.
          batch = batch.FilterKeyVal(
              lambda k, _: not k.endswith('bucket_keys'))
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.

        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)

        if p.use_partitioned_infeed_queue:
          device_assignment = py_utils.GetTpuDeviceAssignment()
          host_device = device_assignment.host_device(
              replica=0, job=tf.flags.FLAGS.tf_master)
          host_id = int(host_device.split('/task:')[1].split('/device:')[0])
          tf.logging.info('host_id: {} host_device: {}'.format(
              host_id, host_device))
          q = tpu_feed._PartitionedInfeedQueue(  # pylint: disable=protected-access
              number_of_tuple_elements=len(dtypes),
              device_assignment=device_assignment,
              host_id=host_id,
              input_partition_dims=[[p.num_partitions, 1] for _ in dtypes],
              tuple_types=dtypes,
              tuple_shapes=shapes)
        else:
          q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
          assert shards is not None
          q.set_number_of_shards(shards)

        queues.append(q)
        tf.logging.info('q=%r', q)

        if p.use_partitioned_infeed_queue:
          input_ops = q.generate_enqueue_ops([batch.Flatten()])
        elif p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops

    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    self._tpu_infeed_op = tpu_infeed_op

    with tf.device(tf.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
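The dense-to-sparse conversion inside the TPU embedding enqueue loop above is easy to miss; here is a small standalone illustration (assuming, as the code does, that -1 is the padding id):

# `split` stands in for one per-core slice of a [batch, max_ids] int32 feature.
split = tf.constant([[3, 7, -1],
                     [5, -1, -1]], dtype=tf.int32)
sample_indices = tf.where(tf.not_equal(split, -1))       # [[0, 0], [0, 1], [1, 0]]
embedding_indices = tf.gather_nd(split, sample_indices)  # [3, 7, 5]
enqueue_data = tpu_embedding_lib.EnqueueData(embedding_indices, sample_indices)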
def __init__(self, params):
  assert issubclass(params.cls, BaseTask)
  super(BaseTask, self).__init__(params)

  p = self.params

  if p.input:
    # TODO(zhifengc): Consider a simpler way to ensure the input
    # generator stops after one epoch.
    if p.is_eval and p.eval:
      seq_inp = issubclass(p.input.cls,
                           base_input_generator.BaseInputGeneratorFromFiles)
      if p.input.num_samples == 0:
        # Dataset size is unknown. Computes eval summary based on num_samples.
        assert p.eval.samples_per_summary > 0
      elif (p.eval.samples_per_summary == 0) or (p.input.num_samples <
                                                 p.eval.samples_per_summary):
        # If we know the dataset size and we want to evaluate the full
        # set, we need to coordinate the input generator to flush out
        # all samples so the evaler and decoder compute metrics on the
        # whole set for each summary step.
        if seq_inp:
          p.input.flush_every_n = p.input.num_samples
        p.eval.samples_per_summary = p.input.num_samples
      if seq_inp and p.input.num_batcher_threads > 1:
        tf.logging.warning('input.num_batcher_threads > 1 inside eval mode. '
                           'The input generator may not iterate over exactly '
                           'one epoch per run')

    with tf.device(
        self.cluster.input_device), py_utils.outside_all_rewrites():
      self.CreateChild('input', p.input)

  self._var_grads = None
  self._encoder = None
  self._online_encoder = None
  self._decoder = None
  self._total_examples = None
  self._total_nans_and_infs = None
  self._loss = None
  self._num_predictions = None
  self._train_op = None
  self._eval_metrics = {}
  self._trainer_verbose_tensors = {}

  # Create the gradient mask.
  self._per_input_gradient_mask = None
  self._shared_global_step = py_utils.GetOrCreateGlobalStep()

  tp = p.train
  if tp:
    if tp.task_global_step:
      self._task_global_step = CreateTaskGlobalStep(p, p.name)
      self._global_step = self._task_global_step
    else:
      self._task_global_step = None
      self._global_step = self._shared_global_step
    if tp.grad_norm_tracker:
      with tf.variable_scope(p.name):
        self.CreateChild('grad_norm_tracker', tp.grad_norm_tracker)
    self.CreateChild('lr_schedule', tp.lr_schedule)
    self.CreateChild('optimizer', tp.optimizer)
    self._UpdateVnConfig()
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = cluster_factory.Current()
  num_tpu_hosts = cluster.num_tpu_hosts
  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    first_batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if first_batch is None:
          first_batch = batch
        flat_batch = batch.FlattenItems()

        shapes, types = [], []
        for k, x in flat_batch:
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
          shapes.append(x.shape)
          types.append(x.dtype)
        q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def _tpu_ordinal_function(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=_tpu_ordinal_function)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    with tf.device(tf.contrib.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return first_batch.Pack(tensors)
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'.format(
                         cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []
    first_batch = None

    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (
        tpu_embedding_collection[0] if tpu_embedding_collection else None)
    tpu_embedding_input_keys = (
        tpu_embedding.feature_to_config_dict.keys()
        if tpu_embedding is not None else [])

    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        tpu_embedding_features = []
        for tpu_embedding_input_key in tpu_embedding_input_keys:
          tpu_embedding_feature = batch.pop(tpu_embedding_input_key)
          tpu_embedding_features.append(
              (tpu_embedding_input_key, tpu_embedding_feature))

        if first_batch is None:
          first_batch = batch
        flat_batch = batch.FlattenItems()

        if tpu_embedding is not None:
          # Use a comprehension rather than `[{}] * n` so that each core gets
          # its own dict instead of n references to the same one.
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for tpu_embedding_input_key, tpu_embedding_feature in (
              tpu_embedding_features):
            tpu_embedding_feature_splitted = tf.split(tpu_embedding_feature,
                                                      num_cores_per_host)
            for core, split in enumerate(tpu_embedding_feature_splitted):
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  tf.squeeze(split, axis=[1]))
              enqueue_dict_per_core[core][
                  tpu_embedding_input_key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        shapes, types = [], []
        for k, x in flat_batch:
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.
          shapes.append(x.shape)
          types.append(x.dtype)
        q = tf.contrib.tpu.InfeedQueue(tuple_types=types, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def _tpu_ordinal_function(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=_tpu_ordinal_function)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              [v for _, v in flat_batch],
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)

    with tf.device(tf.compat.v1.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return first_batch.Pack(tensors)
def __init__(self, params):
  assert issubclass(params.cls, BaseTask)
  # Ensure global_step exists before calling super.
  py_utils.GetOrCreateGlobalStepVar()
  super(BaseTask, self).__init__(params)

  p = self.params

  if p.input:
    # TODO(zhifengc): Consider a simpler way to ensure the input
    # generator stops after one epoch.
    if p.is_eval and p.eval:
      seq_inp = issubclass(p.input.cls,
                           base_input_generator.BaseInputGeneratorFromFiles)
      if p.input.num_samples == 0:
        # Dataset size is unknown. Computes eval summary based on num_samples.
        assert p.eval.samples_per_summary > 0
      elif (p.eval.samples_per_summary == 0) or (p.input.num_samples <
                                                 p.eval.samples_per_summary):
        # If we know the dataset size and we want to evaluate the full
        # set, we need to coordinate the input generator to flush out
        # all samples so the evaler and decoder compute metrics on the
        # whole set for each summary step.
        if seq_inp:
          p.input.flush_every_n = p.input.num_samples
        p.eval.samples_per_summary = p.input.num_samples
      if seq_inp and p.input.num_batcher_threads > 1:
        tf.logging.warning('input.num_batcher_threads > 1 inside eval mode. '
                           'The input generator may not iterate over exactly '
                           'one epoch per run')
    tf.logging.info('input_params: %s', p.input)
    input_params = self.cluster.PlaceInput(p.input)
    with py_utils.outside_all_rewrites():
      self.CreateChild('input', input_params)

  self._encoder = None
  self._online_encoder = None
  self._decoder = None

  self._loss = None
  self._num_predictions = None
  self._train_op = None
  self._eval_metrics = {}
  self._per_example = {}
  self._trainer_verbose_tensors = {}

  # Create the gradient mask.
  self._per_input_gradient_mask = None

  task_global_step_list = tf.get_collection('TASK_GLOBAL_STEP',
                                            '^%s_global_step' % p.name)
  if len(task_global_step_list) > 1:
    raise ValueError('Found multiple task_global_step for task %s' % p.name)
  self._global_step_var = (
      task_global_step_list[0] if len(task_global_step_list) == 1 else
      py_utils.GetOrCreateGlobalStepVar())
  self._global_step = tf.identity(
      self._global_step_var, name='global_step_tensor')

  tp = p.train
  # p.train can be None if this task is the teacher/student task in a
  # DistillationTask.
  if tp and self.cluster.job in ('worker', 'trainer', 'trainer_client',
                                 'controller', 'executor_tpu'):
    self._SetLearnerFromLegacyParams(tp)
    if tp.learner is not None:
      if isinstance(tp.learner, (list, tuple)):
        self.CreateChildren('learners', tp.learner)
      else:
        self.CreateChildren('learners', [tp.learner])
  self._UpdateVnConfig()
def CreateTpuFeeds(self):
  """Creates the TPU infeed queue from preprocessed batch."""
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_cores_per_host = cluster.total_worker_devices // num_tpu_hosts
  tf.logging.info('num_cores_per_host {}'.format(num_cores_per_host))
  tf.logging.info('num_devices_per_split {}'.format(
      cluster.num_devices_per_split))

  assert num_tpu_hosts > 0, ('num_tpu_hosts: %d' % num_tpu_hosts)
  if (cluster.num_devices_per_split > num_cores_per_host and
      p.use_per_host_infeed):
    tf.logging.fatal('Doesn\'t support per host infeed mode when '
                     'num_devices_per_split({}) > num_cores_per_host({})'.format(
                         cluster.num_devices_per_split, num_cores_per_host))
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1

  with py_utils.outside_all_rewrites():
    assert py_utils.use_tpu()
    assert not self._made_tpu_infeed

    shards = tpu_function.get_tpu_context().number_of_shards // num_infeed_hosts
    input_ops_list = []
    queues = []

    tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
    tpu_embedding = (
        tpu_embedding_collection[0] if tpu_embedding_collection else None)
    tpu_emb_input_keys = (
        list(tpu_embedding.feature_to_config_dict.keys())
        if tpu_embedding is not None else [])
    tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)

    batch = None
    for task_id in range(num_infeed_hosts):
      host_device = '/task:{}/device:CPU:0'.format(task_id)
      with tf.device(host_device):
        batch = self.GetPreprocessedInputBatch()
        if 'bucket_keys' in batch:
          # Hack: bucket_keys are not needed on TPU.
          del batch['bucket_keys']
        tf.logging.info('host_device: %s, batch: %r', host_device, batch)

        if tpu_embedding is not None:
          enqueue_dict_per_core = [
              {} for _ in range(tpu_embedding.num_cores_per_host)
          ]
          num_cores_per_host = tpu_embedding.num_cores_per_host
          for key in tpu_emb_input_keys:
            feat = batch[key]
            tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
            for core, split in enumerate(tpu_emb_feat_splitted):
              # Dense to sparse. Note the assumption of a padding id.
              sample_indices = tf.where(tf.not_equal(split, -1))
              embedding_indices = tf.gather_nd(split, sample_indices)
              enqueue_data = tpu_embedding_lib.EnqueueData(
                  embedding_indices, sample_indices)
              enqueue_dict_per_core[core][key] = enqueue_data
          input_ops_list += tpu_embedding.generate_enqueue_ops(
              enqueue_dict_per_core)

        for k, x in batch.FlattenItems():
          assert x.shape.is_fully_defined(), (
              'Shape must be fully defined: %s: %s' % (k, x))
          # TODO(cwhipkey): if it's a string (or other type not supported on
          # TPU), drop it from feeding and on the other end add in an op that
          # fails if used.

        shapes = batch.Transform(lambda x: x.shape).Flatten()
        dtypes = batch.Transform(lambda x: x.dtype).Flatten()
        tf.logging.info('host_device: %s infeed shapes: %r', host_device,
                        shapes)
        tf.logging.info('host_device: %s infeed dtypes: %r', host_device,
                        dtypes)

        q = tpu_feed.InfeedQueue(tuple_types=dtypes, tuple_shapes=shapes)
        queues.append(q)
        assert shards is not None
        q.set_number_of_shards(shards)

        if p.use_per_host_infeed:
          # TODO(ylc/zhifengc): Add this to a policy module and test it.
          def TPUOrdinalFunction(shard_index_in_host):
            device_assignment = py_utils.GetTpuDeviceAssignment()
            if device_assignment:
              # We put both enqueue/dequeue ops at core 0 in each replica.
              replica = device_assignment.lookup_replicas(
                  task_id, 0)[shard_index_in_host]  # pylint: disable=cell-var-from-loop
              return device_assignment.tpu_ordinal(replica=replica)
            else:
              return shard_index_in_host

          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              placement_function=lambda x: host_device,  # pylint: disable=cell-var-from-loop
              tpu_ordinal_function=TPUOrdinalFunction)
        else:
          input_ops = q.split_inputs_and_generate_enqueue_ops(
              batch.Flatten(),
              device_assignment=py_utils.GetTpuDeviceAssignment())

        input_ops_list += input_ops
    tf.logging.info('input_ops_list %s', input_ops_list)
    tpu_infeed_op = tf.group(*input_ops_list)
    self._made_tpu_infeed = True
    # Let trainer.py use multiple threads to drive the infeed op.
    for _ in range(p.tpu_infeed_parallelism):
      tf.add_to_collection(py_utils.ENQUEUE_OPS, tpu_infeed_op)
    # For executor-driven multiple programs, we need more fine-grained
    # access rather than using a single global graph collection.
    self.tpu_infeed_op = tpu_infeed_op

    with tf.device(tf.tpu.core(0)):
      tensors = queues[0].generate_dequeue_op()
  return batch.Pack(tensors)
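For context, a minimal sketch of how the grouped infeed op collected into `py_utils.ENQUEUE_OPS` and `self.tpu_infeed_op` might be driven (an assumption about the trainer side, not code from this file; `sess` and `steps` are hypothetical names): several host threads each loop over the single grouped op so the host-side input pipeline keeps the TPU fed.

def InfeedLoop(sess, tpu_infeed_op, steps):
  # Each run enqueues one step's worth of data for every core this host
  # feeds; multiple threads can run this loop concurrently.
  for _ in range(steps):
    sess.run(tpu_infeed_op)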
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      accumulator = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator_value),
          dtype=p.dtype,
          collections=slot_var_collections)
      accumulator_name = (
          tpu_embedding_table.GetVariableName(host_id) + '/Ftrl')
      tpu_embedding_table.CreateVariable(
          accumulator_name, accumulator, trainable=False)
      accumulator_var = tpu_embedding_table.vars[accumulator_name]

      linear = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_linear_value),
          dtype=p.dtype,
          collections=slot_var_collections)
      linear_name = tpu_embedding_table.GetVariableName(host_id) + '/Ftrl_1'
      tpu_embedding_table.CreateVariable(linear_name, linear, trainable=False)
      linear_var = tpu_embedding_table.vars[linear_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid them being
        # copied to TPU.
        _RemovePrivateVar(tpu_embedding_table, accumulator_name)
        _RemovePrivateVar(tpu_embedding_table, linear_name)

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_ftrl_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  linears=linear_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator, retrieved_linear = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_ftrl_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator),
              tf.assign(linear_var, retrieved_linear))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      m_adam = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name_m = tpu_embedding_table.GetVariableName(host_id) + '/Adam/m'
      tpu_embedding_table.CreateVariable(var_name_m, m_adam, trainable=False)
      m_var = tpu_embedding_table.vars[var_name_m]

      v_adam = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(0.0),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name_v = tpu_embedding_table.GetVariableName(host_id) + '/Adam/v'
      tpu_embedding_table.CreateVariable(var_name_v, v_adam, trainable=False)
      v_var = tpu_embedding_table.vars[var_name_v]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # Remove the slot vars from the variable list to avoid them being
        # copied to TPU.
        _RemovePrivateVar(tpu_embedding_table, var_name_m)
        _RemovePrivateVar(tpu_embedding_table, var_name_v)

        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops.load_tpu_embedding_adam_parameters(
                  parameters=table_var,
                  momenta=m_var,
                  velocities=v_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_m, retrieved_v = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adam_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(m_var, retrieved_m),
              tf.assign(v_var, retrieved_v))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list