def _DecodeStep():
  """Decode call to be compiled for TPU."""
  input_batch = self._task.input.TpuDequeueBatch()
  metrics_dict = self._task.Decode(input_batch)
  self.metrics_nm = py_utils.NestedMap(metrics_dict)
  device = tpu.core(0) if self.spmd else ''
  with tf.device(device):
    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
        self.metrics_nm.Flatten())
    return [outfeed_enqueue]
def CreateSlotVariablesAndOps(self, table_vars, tpu_embedding_table):
  p = self.params

  load_op_list = []
  retrieve_op_list = []

  num_tpu_hosts = tpu_embedding_table.params.num_tpu_hosts
  table_name = tpu_embedding_table.table_name
  slot_var_collections = [tpu_embedding_table.__class__.__name__ + '_vars']

  for host_id, table_var in zip(range(num_tpu_hosts), table_vars):
    # The slot vars should be on the same device as the table var.
    device_name = tpu_embedding_table.GetDeviceName(host_id)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      w_ada = py_utils.WeightParams(
          shape=table_var.shape.as_list(),
          init=py_utils.WeightInit.Constant(p.initial_accumulator),
          dtype=p.dtype,
          collections=slot_var_collections)
      var_name = tpu_embedding_table.GetVariableName(host_id)
      tpu_embedding_table.CreateVariable(
          '%s/Adagrad' % var_name, w_ada, trainable=False)
      accumulator_var = tpu_embedding_table.vars['%s/Adagrad' % var_name]

      # Only the Trainer needs these ops.
      if py_utils.use_tpu():
        # TPU Embedding load/retrieve ops need to be in the outer graph scope.
        with tf.init_scope():
          tf.logging.info('creating load and retrieve ops.')
          load_parameters_op = (
              tpu_embedding_lib.tpu_ops
              .load_tpu_embedding_adagrad_parameters(
                  parameters=table_var,
                  accumulators=accumulator_var,
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          load_op_list.append(load_parameters_op)

          retrieved_table, retrieved_accumulator = (
              tpu_embedding_lib.tpu_ops
              .retrieve_tpu_embedding_adagrad_parameters(
                  table_name=table_name,
                  num_shards=num_tpu_hosts,
                  shard_id=host_id))
          retrieve_parameters_op = tpu_embedding_lib.control_flow_ops.group(
              tf.assign(table_var, retrieved_table),
              tf.assign(accumulator_var, retrieved_accumulator))
          retrieve_op_list.append(retrieve_parameters_op)

  return load_op_list, retrieve_op_list
def __init__(self, inference_graph, subgraph_name=None, checkpoint=None, device_type="gpu", tf_master="", session_config=None, clear_device_placement=False): assert device_type in ["cpu", "gpu", "tpu"] subgraph_name = subgraph_name or "default" if isinstance(inference_graph, six.string_types): tf.logging.info("Reading inference graph from %s.", inference_graph) inference_graph = LoadInferenceGraph(inference_graph, clear_device_placement) self._inference_graph = inference_graph self._checkpoint = checkpoint self._device_type = device_type self._tf_master = tf_master self._session_config = session_config self._graph = tf.Graph() with self._graph.as_default(): tf.logging.info( "Loading inference graph for prediction subgraph_name={}.". format(subgraph_name)) self._saver = tf.train.Saver(saver_def=inference_graph.saver_def) with tf.device("/%s:0" % "cpu" if device_type == "tpu" else device_type): tf.import_graph_def(inference_graph.graph_def, name="") if device_type == "tpu": # If no tpu init op exists, create it here. try: self._graph.get_operation_by_name("tpu_init_op") except KeyError: tf.group(tf.tpu.initialize_system(), name="tpu_init_op") self._graph.finalize() if inference_graph.subgraphs: if subgraph_name not in inference_graph.subgraphs: raise ValueError( "Subgraph %s not defined. Valid subgraphs: %s" % (subgraph_name, list(inference_graph.subgraphs.keys()))) subgraph = inference_graph.subgraphs[subgraph_name] self._fetches = subgraph.fetches self._feeds = subgraph.feeds else: self._fetches = inference_graph.fetches self._feeds = inference_graph.feeds # Lock for creating new sessions. self._sess_lock = threading.Lock() self._cur_sess_id = 0 self._CreateNewSession()
def decode_fn(*infeed_batch):  # pylint: disable=missing-docstring
  # Length 6 is passed when there is no tgt_mask (e.g. decoding) and
  # length 7 is passed when there is a tgt_mask (e.g. fprop).
  self.outfeed = self._config_outfeed(xformer, infeed_batch)
  with tf.device(tf.tpu.core(0)):
    outfeed_op = tpu_ops.outfeed_enqueue_tuple(
        tf.nest.flatten(self.outfeed))
  return [outfeed_op]
def _DecodeStep():
  """Decode call to be compiled for TPU."""
  with py_utils.OpportunisticVariableReuseScope(True):
    self._model.InstantiateVariables()
  input_batch = self._task.input.TpuDequeueBatch()
  metrics_dict = self._task.Decode(input_batch)
  self.metrics_nm = py_utils.NestedMap(metrics_dict)
  device = tpu.core(0) if self.spmd else ''
  with tf.device(device):
    outfeed_enqueue = tpu_ops.outfeed_enqueue_tuple(
        self.metrics_nm.Flatten())
    return [outfeed_enqueue]
def ReplicatedGenericInput(processor, num_replicas, replica_device_fn,
                           **kwargs):
  """Builds a replicated input pipeline.

  This is similar to GenericInput, except that the input processing can be
  distributed across devices and then concatenated at the current device.

  Args:
    processor: see comments for GenericInput.
    num_replicas: the number of input processing replicas. Usually set to
      number of infeed hosts.
    replica_device_fn: a int -> string function that takes the replica index
      in range [0, num_replicas) and returns a TF device string, e.g.,
      lambda i: '/task:{}/device:CPU:0'.format(i)
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension. The batch size will be
      (num_replicas * bucket_batch_limit[...]), i.e.,
      kwargs['bucket_batch_limit'] specifies the per-replica batch size.
    - bucket_keys: a tf.int32 vector.

  Raises:
    RuntimeError: If called in pure Eager/tf.function mode without
      `generic_input_v2_key` defined.
  """
  if num_replicas > 1 and 'bucket_batch_limit' in kwargs:
    assert all(b == max(kwargs['bucket_batch_limit'])
               for b in kwargs['bucket_batch_limit'])
  replica_outputs = []
  if py_utils.IsEagerMode():
    current_key = kwargs.pop('generic_input_v2_key', None)
    if current_key is None:
      raise RuntimeError(_MISSING_KEY_ERR)

  for replica_i in range(num_replicas):
    # Blend `replica_i` into the key for _GENERIC_CACHE_V2 to distinguish
    # different GenericInputV2 ops in the same Datasource object.
    if py_utils.IsEagerMode():
      kwargs['generic_input_v2_key'] = (current_key, replica_i)
    replica_device = replica_device_fn(replica_i)
    with tf.device(replica_device):
      replica_outputs.append(GenericInput(processor, **kwargs))

  output_nmaps, output_bucket_keys = zip(*replica_outputs)
  concat_nmap = tf.nest.map_structure(lambda *t: tf.concat(t, axis=0),
                                      *output_nmaps)
  concat_bucket_keys = tf.concat(output_bucket_keys, axis=0)
  return concat_nmap, concat_bucket_keys
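# A minimal usage sketch (not from the original source): the processor
# argument, file pattern and bucket settings below are illustrative
# assumptions. It shows input processing spread over two infeed hosts, with
# the per-replica batches concatenated on the caller's device.
def _ExampleReplicatedInput(my_record_processor):
  batch, bucket_keys = ReplicatedGenericInput(
      processor=my_record_processor,  # hypothetical, see GenericInput docs.
      num_replicas=2,
      replica_device_fn=lambda i: '/task:{}/device:CPU:0'.format(i),
      # Remaining kwargs are forwarded to GenericInput / x_ops.generic_input.
      # bucket_batch_limit is the per-replica batch size, so the concatenated
      # batch here would have 2 * 8 = 16 examples.
      file_pattern='tfrecord:/path/to/data.tfrecord',
      bucket_upper_bound=[100],
      bucket_batch_limit=[8])
  # In pure Eager/tf.function mode, a generic_input_v2_key kwarg would also be
  # required (see the RuntimeError branch above).
  return batch, bucket_keys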
def testBasic(self):
  devices = _ListDevices(_Target())
  print("\n".join(devices))
  sender, recver = devices[0], devices[-1]
  shape = []
  for dtype in tf.float32, tf.complex64:
    to_send = np.array(3.1415 + 2j).astype(dtype.as_numpy_dtype)
    g = tf.Graph()
    with g.as_default():
      ch = sendrecv.Channel(dtype, shape, sender, recver, "test")
      with tf.device(sender):
        src_val = tf.constant(to_send)
        send_op = ch.Send(src_val)
      with tf.device(recver):
        recv_val = ch.Recv()
    with tf.Session(_Target(), graph=g):
      _, val = self.evaluate([send_op, recv_val])
      self.assertAllClose(to_send, val)
def SendRecv(graph, dtype):
  to_send = np.array(3.1415 + 2j).astype(dtype.as_numpy_dtype)
  with graph.as_default():
    ch = sendrecv.Channel(dtype, shape, sender, recver, "test")
    with tf.device(sender):

      @tf.Defun()
      def Send():
        src_val = tf.constant(to_send)
        ch.Send(src_val)
        return 1.0

      send_op = Send()
    with tf.device(recver):

      @tf.Defun()
      def Recv():
        return ch.Recv()

      recv_val = Recv()
  return send_op, recv_val, to_send
def _CreateChildrenVariables(self):
  p = self.params

  num_cells = len(p.cell_tpl)
  before_tpl_device = ''
  cell_devices = [''] * num_cells
  if py_utils.use_tpu():
    cluster = self.cluster
    before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
    cell_devices = [
        cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
    ]

  for unused_name, l in self._before_layers:
    with tf.device(before_tpl_device):
      l.InstantiateVariables()

  for i, (unused_name, l) in enumerate(self._cells):
    with tf.device(cell_devices[i]):
      l.InstantiateVariables()

  super()._CreateChildrenVariables()
def SendRecv(graph, dtype):
  to_send = np.array(3.1415 + 2j).astype(dtype.as_numpy_dtype)
  with graph.as_default():
    ch = sendrecv.Channel(dtype, shape, sender, recver, "test")
    with tf.device(sender):

      # py_utils.CallDefun requires non-empty inputs. Same below.
      def Send(_):
        src_val = tf.constant(to_send)
        ch.Send(src_val)
        return tf.convert_to_tensor(1.0)

      send_op = py_utils.CallDefun(Send, tf.convert_to_tensor(0))
    with tf.device(recver):

      def Recv(_):
        return ch.Recv()

      recv_val = py_utils.CallDefun(Recv, tf.convert_to_tensor(0))
  return send_op, recv_val, to_send
def _CreateLayerVariables(self):
  p = self.params

  # Reuse the singleton table variables if they were created before.
  all_table_vars = self._tpu_embedding_collection.table_variables
  if self.table_name in all_table_vars:
    embedding_table_vars = all_table_vars[self.table_name]
  else:
    w_pc = py_utils.WeightParams(
        shape=[self._ids_per_shard, p.embedding_dim],
        init=p.params_init,
        dtype=p.dtype,
        collections=[self.__class__.__name__ + '_vars'])

    embedding_table_vars = []
    for i in range(p.num_tpu_hosts):
      device_name = self.GetDeviceName(i)
      with tf.device(device_name), py_utils.outside_all_rewrites():
        var_name = self.GetVariableName(i)
        self.CreateVariable(var_name, w_pc)
        embedding_var = self.vars[var_name]
        embedding_table_vars.append(embedding_var)
        # Remove from _private_vars / _private_thetas to be added later as wm.
        _RemovePrivateVar(self, var_name)

    self._tpu_embedding_collection.AddTableVariables(self.table_name,
                                                     embedding_table_vars)

  if not _ShouldUseTpu(p):
    # We don't want to add this for TrainerTpu, otherwise the identity
    # reference leads to copying the embedding to the TPU for no reason.
    # However, this is needed for CPU (eval/decode/controller).
    self._private_vars['wm'] = embedding_table_vars
    self._private_theta['wm'] = [tf.identity(v) for v in embedding_table_vars]

  # If slot variables and load/retrieve ops were created before, maybe by a
  # different program or task, don't create them again.
  # Note that there should be only one copy of slot variables and
  # load/retrieve ops in the graph and they're shared by different
  # tasks/programs.
  all_load_ops = self._tpu_embedding_collection.load_ops
  if self.table_name not in all_load_ops:
    assert self.table_name not in self._tpu_embedding_collection.retrieve_ops
    # Only trainer and controller (for checkpointing) need slot variables.
    # Only trainer needs load/retrieve ops.
    if not self.do_eval and not p.is_inference:
      load_ops, retrieve_ops = self.optimizer.CreateSlotVariablesAndOps(
          embedding_table_vars, self)
      self._tpu_embedding_collection.AddLoadRetrieveOps(
          self.table_name, load_ops, retrieve_ops)
def CreateVariable(self, name: str, var_params: hyperparams.Params,
                   **kwargs) -> None:
  """Create a variable of this layer according to the parameter `var_params`.

  E.g.::

      def __init__(self, ...):    # A layer's constructor
        self.CreateVariable(
            'weight', py_utils.WeightParams(shape=[100, 100]))

  Args:
    name: Variable name which is used as the key into vars/theta.
    var_params: `Params` used to create the variable.
    **kwargs: Keyword args passed to `.py_utils.CreateVariable`.
  """
  kwargs.setdefault('default_seed', self.params.random_seed)
  if self.params.device_mesh is not None:
    if (len([dim for dim in var_params.shape if dim > 1]) > 1 and
        var_params.tensor_split_dims_mapping is None):
      tf.logging.warning(
          'tensor_split_dims_mapping missing for %s.%s: shape=%s', self.path,
          name, var_params.shape)
  self._CheckName(name)
  if (self.params.skip_lp_regularization and
      py_utils.SKIP_LP_REGULARIZATION not in var_params.collections):
    var_params = py_utils.WeightParams(
        shape=var_params.shape,
        dtype=var_params.dtype,
        init=var_params.init,
        collections=(var_params.collections +
                     [py_utils.SKIP_LP_REGULARIZATION]))
  self._var_symbolic_shape_map[name] = var_params.shape
  var = py_utils.CreateVariable(name, var_params, **kwargs)
  self._private_vars[name] = var
  if py_utils.IsEagerMode():
    # With eager trainer, always use the variable directly.
    value = var
  else:
    if self.cluster.params.worker.gpus_per_replica > 0:
      # On GPU (which always trains a single step per session.run()),
      # reference a tensor in FProp to cache it on device and avoid
      # extraneous sends from reading variables from ps multiple times.
      with tf.device(var.device):
        value = tf.identity(var, name=name)
    else:
      value = var
  self._private_theta[name] = value
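# A minimal sketch of how a layer typically uses CreateVariable (hypothetical
# layer, not from the original source): the name passed in becomes the key
# into self.vars and into the theta NestedMap given to FProp, where the value
# is either the variable itself or the identity-cached tensor created above.
class ExampleProjectionLayer(base_layer.BaseLayer):

  def _CreateLayerVariables(self):
    super()._CreateLayerVariables()
    self.CreateVariable(
        'w',
        py_utils.WeightParams(
            shape=[128, 64],
            init=py_utils.WeightInit.Xavier(),
            dtype=tf.float32))

  def FProp(self, theta, inputs):
    # theta.w resolves to the value stored in _private_theta['w'] above.
    return tf.matmul(inputs, theta.w)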
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""
  for name, (var, grad) in vs_gs.FlattenItems():
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)
      histogram('var_hist/' + name, var)
      histogram('grad_hist/' + name, grad)
def Recv(self):
  """Receives a tensor from the channel."""
  if self._send_tpu_core == -1:
    return tf.raw_ops.Recv(
        tensor_type=self._dtype,
        tensor_name=self._name,
        send_device=self._send_device,
        send_device_incarnation=0,
        recv_device=self._recv_device)
  else:
    with tf.device(self._recv_device):
      return xla.recv(
          self._dtype,
          tensor_name=self._name,
          shape=self._shape,
          name="Recv_" + self._name)
def TpuDequeueBatch(self):
  """Create TPU dequeue ops.

  This should only be called within a TPU context.

  Returns:
    - A NestedMap of the input batch.
  """
  assert self._tpu_queues, 'CreateTpuEnqueueOps must be called first.'
  with tf.device(tf.tpu.core(0)):
    # Note that the dequeue_tuple op on the TPU core
    # only cares about the shape/types being dequeued
    # which is why this is hard-coded to the first Queue.
    tensors = self._tpu_queues[0].generate_dequeue_op()
  return self._batch.Pack(tensors)
def __init__(self, params):
  super().__init__(params)
  p = self.params

  self._before_layers = []
  self._cells = []

  num_cells = len(p.cell_tpl)
  before_tpl_device = ''
  cell_devices = [''] * num_cells
  if py_utils.use_tpu():
    cluster = self.cluster
    before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
    cell_devices = [
        cluster.WorkerDeviceInModelSplit(i) for i in range(num_cells)
    ]

  for l in p.before_tpl:
    with tf.device(before_tpl_device):
      self.CreateChild(l.name, l)
    self._before_layers.append((l.name, self.children[l.name]))

  for i, l in enumerate(p.cell_tpl):
    with tf.device(cell_devices[i]):
      self.CreateChild(l.name, l)
    self._cells.append((l.name, self.children[l.name]))
def _load_graph_from_inference_graph(self, inference_graph):
  """Returns a tf.Graph() constructed from `inference_graph`.

  Args:
    inference_graph: An InferenceGraph proto from which the graph_def is
      loaded.

  Returns:
    A loaded tf.Graph().
  """
  graph = tf.Graph()
  with graph.as_default():
    with tf.device("/%s:0" % "cpu" if self._device_type == "tpu" else
                   self._device_type):
      tf.import_graph_def(inference_graph.graph_def, name="")
  return graph
def Send(self, tensor):
  """Sends a tensor through the channel."""
  assert tensor.dtype == self._dtype
  assert not self._send_called, ("Send called multiple times for %s" %
                                 self._name)
  self._send_called = True
  if self._send_tpu_core == -1:
    return tf.raw_ops.Send(
        tensor=tensor,
        tensor_name=self._name,
        send_device=self._send_device,
        send_device_incarnation=0,
        recv_device=self._recv_device)
  else:
    with tf.device(self._send_device):
      return xla.send(
          tensor, tensor_name=self._name, name="Send_" + self._name)
def __init__(self, *args, **kwargs):
  super().__init__(*args, **kwargs)
  self._job_name = 'trainer'
  with self._graph.as_default(), tf.container(self._container_id):
    try:
      self._task_probs_summary_writers = []
      for task in self._model.task_schedule.tasks:
        path = os.path.join(self._train_dir, task)
        tf.io.gfile.makedirs(path)
        self._task_probs_summary_writers.append(
            self._CreateSummaryWriter(path))
    except AttributeError:
      tf.logging.info('AttributeError. Expected for single task models.')
      self._task_probs_summary_writers = []

    if self.params.cluster.task == 0:
      self._summary_writer = self._CreateSummaryWriter(self._train_dir)
      self._CreateTF2SummaryWriter(self._train_dir)
    else:
      self._summary_writer = None

    with self._cluster, tf.device(
        self._cluster.GetPlacer()), self._TF2SummaryContext():
      self._model = self.params.Instantiate()
      self._params = self._model.params
      self._model.ConstructFPropBPropGraph()
    self._CreateTF2SummaryOps()
    self._initialize_tables = tf.tables_initializer()
    self._initialize_local_vars = tf.local_variables_initializer()
    self.enqueue_ops = tf.get_collection(py_utils.ENQUEUE_OPS)
    tf.logging.info('Trainer number of enqueue ops: %d',
                    len(self.enqueue_ops))

  self._step_rate_tracker = summary_utils.StepRateTracker()

  # Saves the graph def.
  if self.params.cluster.task == 0:
    self._WriteToLog(self.params.ToText(), self._train_dir,
                     'trainer_params.txt')
    tf.io.write_graph(self._graph.as_graph_def(), self._train_dir,
                      'train.pbtxt')

  worker_id = self.params.cluster.task
  self._start_up_delay_steps = (((worker_id + 1) * worker_id / 2) *
                                self.params.train.start_up_delay_steps)
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""
  for name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)
      if py_utils.IsEagerMode():
        histogram_v2(f'var_hist/{name}', var)
        histogram_v2(f'grad_hist/{name}', grad)
      else:
        histogram(f'var_hist/{name}', var)
        histogram(f'grad_hist/{name}', grad)
def __init__(self, inference_graph, subgraph_name=None, checkpoint=None, device_type="gpu", tf_master=""): assert device_type in ["cpu", "gpu", "tpu"] subgraph_name = subgraph_name or "default" if isinstance(inference_graph, six.string_types): tf.logging.info("Reading inference graph from %s.", inference_graph) inference_graph = LoadInferenceGraph(inference_graph) self._inference_graph = inference_graph self._checkpoint = checkpoint self._device_type = device_type self._tf_master = tf_master self._graph = tf.Graph() with self._graph.as_default(): tf.logging.info("Loading inference graph for prediction.") self._saver = tf.train.Saver(saver_def=inference_graph.saver_def) with tf.device("/%s:0" % "cpu" if device_type == "tpu" else device_type): tf.import_graph_def(inference_graph.graph_def, name="") self._graph.finalize() if inference_graph.subgraphs: if subgraph_name not in inference_graph.subgraphs: raise ValueError( "Subgraph %s not defined. Valid subgraphs: %s" % (subgraph_name, list(inference_graph.subgraphs.keys()))) subgraph = inference_graph.subgraphs[subgraph_name] self._fetches = subgraph.fetches self._feeds = subgraph.feeds else: self._fetches = inference_graph.fetches self._feeds = inference_graph.feeds # Lock for creating new sessions. self._sess_lock = threading.Lock() self._cur_sess_id = 0 self._CreateNewSession()
def _FPropSplitInputBatch(self, theta, input_batch):
  """Splits the input batch on the input device."""
  if py_utils.use_tpu():
    return self._FPropTpu(theta, input_batch)

  cluster = self.cluster
  num_splits = cluster.num_splits_per_client

  if not isinstance(input_batch, list):
    input_batch = [input_batch]

  assert len(input_batch) == num_splits, (len(input_batch), num_splits)

  # dev_list_per_replica[i][j] is the i-th worker's j-th device.
  dev_list_per_replica = cluster.available_devices.tolist()

  # Asserts invariant of the total number of splits w.r.t.,
  # splits per worker.
  splits_per_replica = cluster.num_splits_per_replica
  assert num_splits == splits_per_replica * len(dev_list_per_replica), (
      num_splits, splits_per_replica, len(dev_list_per_replica))

  all_metrics = []
  all_per_example_tensors = []
  with cluster:
    for w_id, w_devs in enumerate(dev_list_per_replica):
      # Make local copy of the vars, shard on devices for this worker.
      theta_local = py_utils.CreateLocalTheta(
          theta, w_devs, label='worker %d' % w_id)

      for s_id in range(splits_per_replica):
        # s_id-th split for the w_id-th worker.
        split_id = splits_per_replica * w_id + s_id
        with cluster_factory.SetModelSplit(split_id) as c:
          with tf.device(c.WorkerDeviceInModelSplit(0)):
            with tf.name_scope('tower_%d_%d' % (w_id, s_id)):
              batch = input_batch[split_id]
              metrics, per_example = self.FPropTower(theta_local, batch)
        all_metrics.append(metrics)
        all_per_example_tensors.append(per_example)

  return py_utils.WeightedAvgOfMetrics(
      all_metrics), py_utils.ConcatPerExampleTensors(all_per_example_tensors)
def testPSRandomSize(self):
  p = cluster_factory.Cluster.Params()
  p.worker.name = '/job:trainer'
  p.ps.name = '/job:ps'
  p.ps.replicas = 10
  c = cluster_factory.Cluster(p)
  g = tf.Graph()
  vs = []
  np.random.seed(301)
  with g.as_default():
    with tf.device(c.GetPlacer()):
      # Creates 200 variables with different sizes.
      for i in range(200):
        if i % 13:
          size = np.random.randint(10000)
        elif i % 7:
          size = np.random.randint(100)
        else:
          size = np.random.randint(10)
        vs.append(tf.get_variable('x%d' % i, shape=(size)))
      sum_all = tf.add_n([tf.reduce_sum(x) for x in vs])
  # Computes the total size of variables placed on each device.
  total_size = {}  # device name -> size
  for v in vs:
    size = tf.TensorShape(v.op.get_attr('shape')).num_elements()
    if v.device in total_size:
      total_size[v.device] += size
    else:
      total_size[v.device] = size
  for (device, allocated) in zip(
      sorted(total_size),
      [91701, 91361, 90346, 88738, 87240, 89265, 91944, 92472, 88051, 95053]):
    self.assertEqual(total_size[device], allocated)
  self.assertEqual(
      sum_all.device,
      cluster.MakeDeviceString(
          job_name='/job:trainer',
          replica_id=0,
          task_id=0,
          device_name='CPU',
          device_id=0))
def _CreateVariableInternal(self, name, meta):
  """Immediately creates the variable described by `meta`.

  DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
  self.CreateVariable() to create variables.

  Args:
    name: The variable name.
    meta: A CreateVariableMeta describing the variable to be created.
  """
  meta.kwargs.setdefault('default_seed', self.params.random_seed)
  var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
  self._private_vars[name] = var
  if resource_variable_ops.is_resource_variable(var):
    value = var
  else:
    with tf.device(var.device):
      value = tf.identity(var)
  if meta.theta_fn is not None:
    value = meta.theta_fn(value)
  self._private_theta[name] = value
def _OutfeedDequeue(self):
  """Collect outfeed dequeue from all devices."""
  num_outfeeds = len(self.metrics_nm.Flatten())
  outfeed_ops = [[]] * num_outfeeds
  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment
  for replica in range(device_assignment.num_replicas):
    num_cores_per_replica = 1 if self.spmd else (
        device_assignment.num_cores_per_replica)
    for core in range(num_cores_per_replica):
      with tf.device(device_assignment.host_device(replica, core)):
        outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
            dtypes=[x.dtype for x in self.metrics_nm.Flatten()],
            shapes=[x.shape for x in self.metrics_nm.Flatten()],
            device_ordinal=device_assignment.tpu_ordinal(replica, core))
        for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
          outfeed_ops[idx_outfeed] = outfeed_ops[idx_outfeed] + [out_feed]
  return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
def _CreateVariableInternal(self, name: str,
                            meta: CreateVariableMeta) -> None:
  """Immediately creates the variable described by `meta`.

  DO NOT OVERRIDE. For internal use only. Subclasses of BaseLayer should use
  self.CreateVariable() to create variables.

  Args:
    name: The variable name.
    meta: A CreateVariableMeta describing the variable to be created.
  """
  meta.kwargs.setdefault('default_seed', self.params.random_seed)
  var = py_utils.CreateVariable(name, meta.var_params, **meta.kwargs)
  self._private_vars[name] = var
  if self.cluster.params.worker.gpus_per_replica > 0:
    # On GPU (which always trains a single step per session.run()), reference
    # a tensor in FProp to cache it on device and avoid extraneous sends from
    # reading variables from ps multiple times.
    with tf.device(var.device):
      value = tf.identity(var)
  else:
    # Pass the resource variable directly into the training loop.
    value = var

  # Due to b/174956514, we have to annotate the use of the variable once,
  # otherwise, the sharding annotation on the var will be ignored.
  # TODO(yonghui): Get rid of this once b/174956514 is fixed.
  if (meta.var_params.device_mesh is not None and
      var.shape.rank == len(meta.var_params.tensor_split_dims_mapping)):
    value = gshard_utils.MeshSplit(
        value,
        meta.var_params.device_mesh,
        meta.var_params.tensor_split_dims_mapping,
        use_sharding_op=True)

  if meta.theta_fn is not None:
    self._private_theta_fn[name] = meta.theta_fn
  self._private_theta[name] = value
def ReplicatedGenericInput(processor, num_replicas, replica_device_fn,
                           **kwargs):
  """Builds a replicated input pipeline.

  This is similar to GenericInput, except that the input processing can be
  distributed across devices and then concatenated at the current device.

  Args:
    processor: see comments for GenericInput.
    num_replicas: the number of input processing replicas. Usually set to
      number of infeed hosts.
    replica_device_fn: a int -> string function that takes the replica index
      in range [0, num_replicas) and returns a TF device string, e.g.,
      lambda i: '/task:{}/device:CPU:0'.format(i)
    **kwargs: additional keyword args for x_ops.generic_input.

  Returns:
    A tuple of (outputs, bucket_keys):

    - outputs: a NestedMap or a list of tensors, similar to `processor`'s
      return, except every tensor will have an additional dimension 0 that
      represents the batch dimension. The batch size will be
      (num_replicas * bucket_batch_limit[...]), i.e.,
      kwargs['bucket_batch_limit'] specifies the per-replica batch size.
    - bucket_keys: a tf.int32 vector.
  """
  if num_replicas > 1 and 'bucket_batch_limit' in kwargs:
    assert all(b == max(kwargs['bucket_batch_limit'])
               for b in kwargs['bucket_batch_limit'])
  replica_outputs = []
  for replica_i in range(num_replicas):
    replica_device = replica_device_fn(replica_i)
    with tf.device(replica_device):
      replica_outputs.append(GenericInput(processor, **kwargs))

  output_nmaps, output_bucket_keys = zip(*replica_outputs)
  concat_nmap = tf.nest.map_structure(lambda *t: tf.concat(t, axis=0),
                                      *output_nmaps)
  concat_bucket_keys = tf.concat(output_bucket_keys, axis=0)
  return concat_nmap, concat_bucket_keys
def _WrapNonLingvoVars(dest_layer: base_layer.BaseLayer,
                       variables: Collection[tf.Variable],
                       trainable_variables: Collection[tf.Variable] = ()):
  """Adds variables to the given lingvo layer and appropriate graph collections.

  This function helps wrap variables created outside of lingvo so they are
  correctly handled by lingvo's trainer and checkpointer. It does the
  following:

  - makes all `variables` trackable through `dest_layer.vars`;
  - ensures `variables` are in the `tf.global_variables()` graph collection so
    the trainer can initialize them;
  - adds the `trainable_variables` subset to the `tf.trainable_variables()`
    graph collection, so they are visible to the learner (i.e. can be trained).

  Args:
    dest_layer: Lingvo layer to add the `variables` to.
    variables: The non-lingvo variables to wrap.
    trainable_variables: The subset of `variables` to ensure are trainable.
  """
  global_collection = set(tf.global_variables())
  for v in variables:
    assert v in global_collection
    name = v.name.split(':')[0]
    # pylint: disable=protected-access
    dest_layer._private_vars[name] = v
    with tf.device(v.device):
      dest_layer._private_theta[name] = tf.identity(v)
    # pylint: enable=protected-access

  trainable_collection = set(tf.trainable_variables())
  for v in trainable_variables:
    if v not in trainable_collection:
      tf.logging.warning(
          'Wrapped var %s not in trainable collection; adding it.', v.name)
      tf.compat.v1.add_to_collection(
          tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, v)
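# A minimal usage sketch (hypothetical names, not from the original source):
# wrap a variable created outside of lingvo, e.g. by a third-party library, so
# the lingvo trainer/checkpointer can reach it through `my_layer.vars`. In a
# tf.compat.v1 graph, get_variable adds the variable to the global (and, when
# trainable, the trainable) collection, which satisfies the assertion above.
external_kernel = tf.compat.v1.get_variable(
    'external_kernel', shape=[4, 4],
    initializer=tf.compat.v1.zeros_initializer())
_WrapNonLingvoVars(
    dest_layer=my_layer,  # assumed to be an existing BaseLayer instance.
    variables=[external_kernel],
    trainable_variables=[external_kernel])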
def _CreateLayerVariables(self):
  p = self.params
  w_pc = py_utils.WeightParams(
      shape=[self._ids_per_shard, p.embedding_dim],
      init=p.params_init,
      dtype=p.dtype,
      collections=[self.__class__.__name__ + '_vars'])

  embedding_table_vars = []
  for i in range(p.num_tpu_hosts):
    device_name = self.GetDeviceName(i)
    with tf.device(device_name), py_utils.outside_all_rewrites():
      var_name = self.GetVariableName(i)
      self.CreateVariable(var_name, w_pc)
      embedding_var = self.vars[var_name]
      embedding_table_vars.append(embedding_var)
      # Remove from _private_vars / _private_thetas to be added later as wm.
      del self._private_vars[var_name]
      del self._private_theta[var_name]

  self._tpu_embedding_collection.AddTableVariables(self.table_name,
                                                   embedding_table_vars)

  if not py_utils.use_tpu():
    # We don't want to add this for TrainerTpu, otherwise the identity
    # reference leads to copying the embedding to the TPU for no reason.
    # However, this is needed for CPU (eval/decode/controller).
    self._private_vars['wm'] = embedding_table_vars
    self._private_theta['wm'] = [tf.identity(v) for v in embedding_table_vars]

  # Only trainer and controller need slot variables and load/retrieve ops.
  if not self.do_eval:
    self._load_op_list, self._retrieve_op_list = (
        self.optimizer.CreateSlotVariablesAndOps(embedding_table_vars, self))
def LoopBody(i, *input_arrays):
  """Process outfeed data for a single TpuTrainStep.

  Args:
    i: current loop index.
    *input_arrays: One tf.TensorArray per outfeed tensor.

  Returns:
    i+1 (new index) plus post-write tf.TensorArray handles.
  """
  # Outfeed ops execute on each JF node, so they must be located on the
  # nodes.
  outfeed_devices = []
  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment
  for replica in range(device_assignment.num_replicas):
    num_cores_per_replica = 1 if self.spmd else (
        device_assignment.num_cores_per_replica)
    for core in range(num_cores_per_replica):
      with tf.device(device_assignment.host_device(replica, core)):
        outfeed_devices.append(
            tpu_ops.outfeed_dequeue_tuple(
                tensor_types,
                tensor_shapes,
                device_ordinal=device_assignment.tpu_ordinal(replica, core)))

  offset = i * num_devices
  output_arrays = list(input_arrays)
  # Each output_array holds a different per-example tensor. We get results
  # for each tensor from each TPU for each TpuTrainStep call.
  for j in range(len(output_arrays)):
    for k in range(len(outfeed_devices)):
      output_arrays[j] = output_arrays[j].write(offset + k,
                                                outfeed_devices[k][j])

  return tuple([i + 1] + output_arrays)