def __init__(self, params):
  """Creates the before-layers and the sequential cells as child layers.

  Children built from `p.before_tpl` are recorded in `self._before_layers`
  and children built from `p.cell_tpl` in `self._cells`, each as a
  `(name, layer)` pair in template order. Every template must carry a
  non-empty `name`.
  """
  super(SeqLayer, self).__init__(params)
  p = self.params
  assert p.name
  num_cells = len(p.cell_tpl)
  self._before_layers = []
  self._cells = []

  # On TPU, the before-layers go to model-split device 0 and cell k goes to
  # model-split device k; otherwise every child gets an empty device string.
  if py_utils.use_tpu():
    cluster = self.cluster
    before_tpl_device = cluster.WorkerDeviceInModelSplit(0)
    cell_devices = [
        cluster.WorkerDeviceInModelSplit(k) for k in range(num_cells)
    ]
  else:
    before_tpl_device = ''
    cell_devices = [''] * num_cells

  for tpl in p.before_tpl:
    with tf.device(before_tpl_device):
      assert tpl.name
      self.CreateChild(tpl.name, tpl)
      self._before_layers.append((tpl.name, self.children[tpl.name]))

  for tpl, cell_device in zip(p.cell_tpl, cell_devices):
    with tf.device(cell_device):
      assert tpl.name
      self.CreateChild(tpl.name, tpl)
      self._cells.append((tpl.name, self.children[tpl.name]))
def TrainAndDecodeEpoch(i, host_device):
  """Train and decode infeed for an epoch.

  Builds a train-infeed while loop for host `i`, then a decode-infeed while
  loop that is created under a control dependency on the train loop, so the
  decode infeed only starts after the train infeed loop finishes.

  Args:
    i: host index.
    host_device: host device string

  Returns:
    Decode with control deps on train node.
  """
  train_steps = self._train_steps_per_loop
  decode_steps = self._decode_steps_per_loop
  tf.logging.info('self._train_steps_per_loop: %d', train_steps)
  tf.logging.info('self._decode_steps_per_loop: %d', decode_steps)

  def _TrainEnqueue():
    return self._train_input.CreatePerHostEnqueueOp(i)

  def _DecodeEnqueue():
    return self._decode_input.CreatePerHostEnqueueOp(i)

  train_loop = wrap_computation_in_while_loop(_TrainEnqueue, train_steps,
                                              host_device)
  with tf.device(host_device), tf.control_dependencies([train_loop]):
    decode_loop = wrap_computation_in_while_loop(_DecodeEnqueue, decode_steps,
                                                 host_device)
  return decode_loop
def LoopBody(i, *input_arrays):
  """Process outfeed data for a single TpuTrainStep.

  Dequeues one outfeed tuple per (replica, core) and writes element j of
  each tuple into TensorArray j at index `i * num_devices + core_position`.
  `tensor_types`, `tensor_shapes` and `num_devices` are closed over from the
  enclosing scope (not visible here).

  Args:
    i: current loop index.
    *input_arrays: One tf.TensorArray per outfeed tensor.

  Returns:
    i+1 (new index) plus post-write tf.TensorArray handles.
  """
  # Outfeed ops execute on each JF node, so they must be located on the
  # nodes.
  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment
  per_core_outfeeds = []
  for replica in range(device_assignment.num_replicas):
    for core in range(device_assignment.num_cores_per_replica):
      with tf.device(device_assignment.host_device(replica, core)):
        per_core_outfeeds.append(
            tpu_ops.outfeed_dequeue_tuple(
                tensor_types,
                tensor_shapes,
                device_ordinal=device_assignment.tpu_ordinal(replica, core)))

  base = i * num_devices
  # Each array holds a different per-example tensor; fill it with that
  # tensor's value from every dequeued core.
  updated_arrays = []
  for array_idx, tensor_array in enumerate(input_arrays):
    for core_pos, dequeued in enumerate(per_core_outfeeds):
      tensor_array = tensor_array.write(base + core_pos, dequeued[array_idx])
    updated_arrays.append(tensor_array)
  return tuple([i + 1] + updated_arrays)
def FProp(self, theta, *args):
  """Runs the before-layers, then each cell on its own model-split device.

  Cell k is built under `cluster.WorkerDeviceInModelSplit(k)`. Any
  wrap-around when there are more cells than devices is handled by the
  cluster, not here. Between layers, the running output is normalized with
  `_ToTuple` and splatted into the next layer's `FProp`.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    *args: Input args

  Returns:
    The last cell's output. The original docstring describes it as a list
    holding one [batch_size, feature_height, feature_width, channel] tensor;
    that shape is not checked here.
  """
  cell_count = len(self.params.cell_tpl)
  cluster = self.cluster
  outputs = args
  for layer_name, layer in self._before_layers:
    layer_theta = theta[layer_name]
    outputs = layer.FProp(layer_theta, *_ToTuple(outputs))
  for k in range(cell_count):
    with tf.device(cluster.WorkerDeviceInModelSplit(k)):
      cell_name, cell = self._cells[k]
      outputs = cell.FProp(theta[cell_name], *_ToTuple(outputs))
  return outputs
def FProp(self, theta, *args):
  """FProp through multiple devices in the split.

  With a single device per split, `self.sub.FProp` is called directly.
  Otherwise the inputs are split along axis 0 into one piece per device, and
  piece k runs on `cluster.WorkerDeviceInModelSplit(k)`. The per-device
  results are concatenated back along axis 0.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    *args: A tuple of Tensors (one or more). Every tensor's first dimension
      is the same (the batch dimension).

  Returns:
    The sub layer's output: a tuple if the sub layer returns a tuple,
    otherwise a single tensor.
  """
  p = self.params
  with tf.name_scope(p.name):
    assert all(isinstance(x, tf.Tensor) for x in args)
    cluster = self.cluster
    num = cluster.num_devices_per_split
    if num == 1:
      return self.sub.FProp(theta.sub, *args)

    shards = py_utils.SplitRecursively(list(args), num, axis=0)
    per_device_outputs = []
    for k, shard in enumerate(shards):
      device = cluster.WorkerDeviceInModelSplit(k)
      tf.logging.info('%d on device %s', k, device)
      with tf.device(device):
        ys = self.sub.FProp(theta.sub, *shard)
      # Tuples become lists so ConcatRecursively walks them element-wise;
      # a single tensor is kept as-is.
      per_device_outputs.append(list(ys) if isinstance(ys, tuple) else ys)

    merged = py_utils.ConcatRecursively(per_device_outputs, axis=0)
    return tuple(merged) if isinstance(merged, list) else merged
def _OutfeedDequeue(self):
  """Collect outfeed dequeue from all devices.

  Dequeues one outfeed tuple per core, with dtypes and shapes taken from
  `self.metrics_nm`. When `self.spmd` is set, only core 0 of each replica is
  dequeued. Each tuple is packed back into the structure of
  `self.metrics_nm` and passed through
  `self._decode_model_task.PostProcessDecodeHost`. The selected keys are then
  concatenated across cores along axis 0.

  Returns:
    A dict mapping each hard-coded key to the axis-0 concatenation of that
    key's value from every dequeued core.
  """
  # Hard-coding for Transformer/MLPerf.
  keys = ['target_ids', 'eval_weight', 'tlen', 'top_ids', 'top_lens']
  concat_lists = {key: [] for key in keys}

  # Loop-invariant: the outfeed signature is the same for every core.
  flat_metrics = self.metrics_nm.Flatten()
  dtypes = [x.dtype for x in flat_metrics]
  shapes = [x.shape for x in flat_metrics]

  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment
  num_cores_per_replica = (
      1 if self.spmd else device_assignment.num_cores_per_replica)
  for replica in range(device_assignment.num_replicas):
    for core in range(num_cores_per_replica):
      with tf.device(device_assignment.host_device(replica, core)):
        outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
            dtypes=dtypes,
            shapes=shapes,
            device_ordinal=device_assignment.tpu_ordinal(replica, core))
        packed = tf.nest.pack_sequence_as(self.metrics_nm, outfeeds_per_core)
        outfeed_dict = self._decode_model_task.PostProcessDecodeHost(packed)
        for key in keys:
          concat_lists[key].append(outfeed_dict[key])

  return {key: tf.concat(concat_lists[key], 0) for key in keys}
def CreateTpuEmbeddingEnqueueOps(self):
  """Creates the TpuEmbedding enqueue ops on the host.

  Note that this must be called after the instantiation of the
  monolithic TPUEmbeddingLayer.

  Looks up the TPU embedding object from the `py_utils.TPU_EMBEDDING`
  collection and returns early if there is none. Otherwise, for each infeed
  host, it splits every embedding feature of `self._batch` across that
  host's cores. It converts each dense split into (sample_indices,
  embedding_indices) by treating -1 as padding, then generates the enqueue
  ops. All enqueue ops are grouped into one op and appended to
  `self._tpu_infeed_op`.
  """
  p = self.params
  cluster = self.cluster
  num_tpu_hosts = cluster.num_tpu_hosts
  num_infeed_hosts = num_tpu_hosts if p.use_per_host_infeed else 1
  tpu_embedding_collection = tf.get_collection(py_utils.TPU_EMBEDDING)
  # Only the first registered embedding object is used.
  tpu_embedding = (tpu_embedding_collection[0]
                   if tpu_embedding_collection else None)

  enqueue_ops = []

  if num_tpu_hosts > 1 and tpu_embedding is not None:
    if not p.use_per_host_infeed:
      # NOTE(review): depending on the TF version, tf.logging.fatal may only
      # log at FATAL level without aborting; execution would then continue
      # with num_infeed_hosts == 1. Confirm whether this should raise.
      tf.logging.fatal(
          'TPU Embedding must be used with per_host_infeed with multiple '
          'TPU host topologies.')
  tpu_emb_input_keys = (list(tpu_embedding.feature_to_config_dict.keys())
                        if tpu_embedding is not None else [])
  tf.logging.info('tpu_emb_input_keys: %r', tpu_emb_input_keys)
  if not tpu_embedding:
    return

  for task_id in range(num_infeed_hosts):
    host_device = '/task:{}/device:CPU:0'.format(task_id)
    with tf.device(host_device):
      if isinstance(self._batch, py_utils.NestedMap):
        # Hack: bucket_keys and xxx.bucket_keys are not needed on TPU.
        # Note that when MultiTaskData is used, bucket_keys will be at the
        # second level of the dictionary.
        self._batch = self._batch.FilterKeyVal(
            lambda k, _: not k.endswith('bucket_keys'))
      # NOTE(review): every host iteration reads the same self._batch; confirm
      # that a single batch per infeed host is the intended layout.
      tf.logging.info('host_device: %s, batch: %r', host_device, self._batch)

      # One {feature_key: EnqueueData} dict per core on this host.
      enqueue_dict_per_core = [
          {} for _ in range(tpu_embedding.num_cores_per_host)
      ]
      num_cores_per_host = tpu_embedding.num_cores_per_host
      for key in tpu_emb_input_keys:
        feat = self._batch[key]
        # Assumes the leading dim of `feat` is divisible by
        # num_cores_per_host (tf.split requires it) -- TODO confirm upstream.
        tpu_emb_feat_splitted = tf.split(feat, num_cores_per_host)
        for core, split in enumerate(tpu_emb_feat_splitted):
          # Dense to sparse. Note the assumption of a padding id.
          sample_indices = tf.where(tf.not_equal(split, -1))
          embedding_indices = tf.gather_nd(split, sample_indices)
          enqueue_data = tpu_embedding_lib.EnqueueData(
              embedding_indices, sample_indices)
          enqueue_dict_per_core[core][key] = enqueue_data
      enqueue_ops += tpu_embedding.generate_enqueue_ops(
          enqueue_dict_per_core)
  self._tpu_infeed_op.append(tf.group(*enqueue_ops))
def _DecodeStep():
  """Decode call to be compiled for TPU.

  Dequeues a batch, runs the model task's `Decode`, stores the metrics as a
  NestedMap on `self.metrics_nm`, and enqueues the flattened metrics to the
  outfeed. In SPMD mode the outfeed enqueue is pinned to `tpu.core(0)`.

  Returns:
    A one-element list holding the outfeed enqueue op.
  """
  task = self._model_task
  input_batch = task.input_generator.TpuDequeueBatch()
  self.metrics_nm = py_utils.NestedMap(task.Decode(input_batch))
  outfeed_device = tpu.core(0) if self.spmd else ''
  with tf.device(outfeed_device):
    enqueue_op = tpu_ops.outfeed_enqueue_tuple(self.metrics_nm.Flatten())
  return [enqueue_op]
def wrap_computation_in_while_loop(op_fn, n, host_device):
  """Wraps the ops generated by `op_fn` in tf.while_loop.

  The loop is placed on `host_device`. Its counter starts at
  `tf.constant(0)` and the loop runs while `counter < n`. Each iteration
  calls `op_fn()` and advances the counter only after the returned op(s)
  have run. Iterations run sequentially (`parallel_iterations=1`).

  Args:
    op_fn: Zero-argument callable returning an op or a list of ops.
    n: Number of iterations.
    host_device: Device string for the loop.

  Returns:
    The output of `tf.while_loop`.
  """

  def _KeepGoing(counter):
    return tf.less(counter, n)

  def _Step(counter):
    produced = op_fn()
    deps = produced if isinstance(produced, list) else [produced]
    with tf.control_dependencies(deps):
      # tf.Print also logs the pre-increment counter each iteration.
      return tf.Print(counter + 1, [counter], 'while_loop:')

  with tf.device(host_device):
    return tf.while_loop(
        _KeepGoing, _Step, [tf.constant(0)], parallel_iterations=1)
def TpuDequeueBatch(self):
  """Create TPU dequeue ops.

  This should only be called within a TPU context.

  Returns:
  - A NestedMap of the input batch.
  """
  assert self._tpu_queues, 'CreateTpuEnqueueOps must be called first.'
  # The dequeue op on the TPU core only depends on the shapes/types being
  # dequeued, so the first queue stands in for all of them.
  reference_queue = self._tpu_queues[0]
  with tf.device(tf.tpu.core(0)):
    dequeued = reference_queue.generate_dequeue_op()
  return self._batch.Pack(dequeued)
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients.

  For each `(var, grad)` pair yielded by `vs_gs.FlattenItems()`, summaries
  are built on the variable's device under a `<sanitized name>/summary`
  name scope.
  - For `tf.IndexedSlices` gradients, only the gathered rows of the variable
    and the gradient values are summarized.
  - For complex dtypes, magnitudes are summarized instead.
  """
  for key, (var, grad) in vs_gs.FlattenItems():
    scope = py_utils.SanitizeScopeKey(key)
    with tf.device(var.device), tf.name_scope(scope + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var, grad = tf.gather(var, grad.indices), grad.values
      if var.dtype.is_complex:
        var, grad = tf.abs(var), tf.abs(grad)
      histogram('var_hist/' + scope, var)
      histogram('grad_hist/' + scope, grad)
def _DecodeStep():
  """Decode call to be compiled for TPU.

  Inside opportunistic variable reuse and an eval cluster scope, this
  instantiates the decode model, attaches `self._decode_input` as the task's
  `input` child, decodes one dequeued batch, stores the metrics on
  `self.metrics_nm`, and enqueues them to the outfeed. In SPMD mode the
  enqueue is pinned to `tpu.core(0)`.

  Returns:
    A one-element list holding the outfeed enqueue op.
  """
  with py_utils.OpportunisticVariableReuseScope(True), \
      cluster_factory.SetEval(True):
    self._decode_model = self._decode_task_params.Instantiate()
    self._decode_model_task = self._decode_model.GetTask()
    task = self._decode_model_task
    task.AddChild('input', self._decode_input)
    input_batch = task.input_generator.TpuDequeueBatch()
    self.metrics_nm = py_utils.NestedMap(task.Decode(input_batch))
    outfeed_device = tpu.core(0) if self.spmd else ''
    with tf.device(outfeed_device):
      enqueue_op = tpu_ops.outfeed_enqueue_tuple(self.metrics_nm.Flatten())
    return [enqueue_op]
def Send(self, tensor):
  """Sends a tensor through the channel.

  May be called at most once per channel, and `tensor` must have the
  channel's dtype.
  - When `self._send_tpu_core` is not -1, this builds `xla.send` under
    `tf.device(self._send_device)`.
  - Otherwise it builds `tf.raw_ops.Send` with explicit send and receive
    devices.
  """
  assert tensor.dtype == self._dtype
  assert not self._send_called, ("Send called multiple times for %s" %
                                 self._name)
  self._send_called = True

  if self._send_tpu_core != -1:
    with tf.device(self._send_device):
      return xla.send(tensor, tensor_name=self._name, name="Send_" + self._name)

  return tf.raw_ops.Send(
      tensor=tensor,
      tensor_name=self._name,
      send_device=self._send_device,
      send_device_incarnation=0,
      recv_device=self._recv_device)
def _OutfeedDequeue(self):
  """Collect outfeed dequeue from all devices.

  Dequeues one outfeed tuple per core, with dtypes and shapes taken from
  `self.metrics_nm`. When `self.spmd` is set, only core 0 of each replica is
  dequeued. Each outfeed position is then concatenated across cores along
  axis 0.

  Returns:
    A list with one concatenated tensor per entry of
    `self.metrics_nm.Flatten()`, in flattened order.
  """
  # Loop-invariant: the outfeed signature is the same for every core.
  flat_metrics = self.metrics_nm.Flatten()
  dtypes = [x.dtype for x in flat_metrics]
  shapes = [x.shape for x in flat_metrics]
  # One independent list per outfeed position. `[[]] * n` would make every
  # slot alias the same list object.
  outfeed_ops = [[] for _ in flat_metrics]

  device_assignment = py_utils.GetTpuDeviceAssignment()
  assert device_assignment
  num_cores_per_replica = (
      1 if self.spmd else device_assignment.num_cores_per_replica)
  for replica in range(device_assignment.num_replicas):
    for core in range(num_cores_per_replica):
      with tf.device(device_assignment.host_device(replica, core)):
        outfeeds_per_core = tpu_ops.outfeed_dequeue_tuple(
            dtypes=dtypes,
            shapes=shapes,
            device_ordinal=device_assignment.tpu_ordinal(replica, core))
        for idx_outfeed, out_feed in enumerate(outfeeds_per_core):
          outfeed_ops[idx_outfeed].append(out_feed)
  return [tf.concat(per_outfeed, 0) for per_outfeed in outfeed_ops]
def Recv(self):
  """Receives a tensor from the channel.

  - When `self._send_tpu_core` is not -1, this builds `xla.recv` under
    `tf.device(self._recv_device)`.
  - Otherwise it builds `tf.raw_ops.Recv` and sets the result's static
    shape to `self._shape`.
  """
  if self._send_tpu_core != -1:
    with tf.device(self._recv_device):
      return xla.recv(
          self._dtype,
          tensor_name=self._name,
          shape=self._shape,
          name="Recv_" + self._name)

  result = tf.raw_ops.Recv(
      tensor_type=self._dtype,
      tensor_name=self._name,
      send_device=self._send_device,
      send_device_incarnation=0,
      recv_device=self._recv_device)
  result.set_shape(self._shape)
  return result
def Export(cls,
           model_cfg,
           model_task_name=None,
           device_options=InferenceDeviceOptions(
               device='',
               retain_device_placement=False,
               var_options=None,
               gen_init_op=True,
               dtype_override=None),
           freeze_checkpoint=None,
           freeze_defaults=False,
           export_path=None,
           subgraph_filter=None,
           random_seed=None,
           disable_packed_input=True):
  """Exports a InferenceGraph proto with piecewise subgraphs.

  Sets FLAGS.enable_asserts to False unless user explicitly sets it to True.

  Note: `model_cfg` is mutated in place. Its random_seed and is_inference
  fields are set, packed_input may be disabled, and dtypes may be forced to
  bfloat16.

  Args:
    model_cfg: a Params instance as returned by
      model_registry.GetParams(modelname, 'Test') or model_params.Model().
    model_task_name: The task to generate an inference graph for. Should be
      None for single-task models.
    device_options: Device options for the accelerator used for serving.
    freeze_checkpoint: The checkpoint to load. Loads and freezes the model if
      given.
    freeze_defaults: Default initializes the graph and freeze. Useful for
      early testing of downstream tools without having a checkpoint.
    export_path: If not None, write the inference graph in ASCII to this path.
    subgraph_filter: A list of subgraph names. If not None or empty, export
      only this list of inference subgraphs.
    random_seed: Fixes the random seed in the exported inference graph.
    disable_packed_input: Disable packed input for inference writing purposes.

  Returns:
    InferenceGraph proto.

  Raises:
    ValueError: if freezing is requested and
      `cls._DeviceSupportsFreezing(device_options)` is True. Errors for
      unsupported subgraphs are not raised in this body; presumably they come
      from `model_task.Inference()` -- confirm.
  """
  assert issubclass(model_cfg.cls, base_model.BaseModel)

  # Disable assertions unless user explicitly enables it.
  if FLAGS['enable_asserts'].using_default_value:
    FLAGS.enable_asserts = False

  # TODO(laurenzo): Work out how much we need to specify here in terms of
  # cluster configuration.
  cls._SetClusterParams(model_cfg.cluster, device_options)

  # Configure the model.
  model_cfg.random_seed = random_seed
  model_cfg.is_inference = True

  if disable_packed_input:

    def _DisablePackedInput(task):
      # Turns off packed_input on the encoder/decoder params when present.
      if (_ParamExists(task, 'encoder') and
          _ParamExists(task.encoder, 'packed_input')):
        task.encoder.packed_input = False
      if (_ParamExists(task, 'decoder') and
          _ParamExists(task.decoder, 'packed_input')):
        task.decoder.packed_input = False

    if issubclass(model_cfg.cls, base_model.MultiTaskModel):
      for _, task_param in model_cfg.task_params.IterParams():
        _DisablePackedInput(task_param)
    else:
      _DisablePackedInput(model_cfg.task)

  tf.logging.info('Model %s params:', model_cfg.name)
  for line in model_cfg.ToText().split('\n'):
    tf.logging.info('%s', line)

  # Instantiate the graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.random.set_seed(random_seed)
    cluster = model_cfg.cluster.Instantiate()
    device = cluster.GetPlacer()
    tpu_const_scope = _DummyScope()
    if (IsTpu(device_options) and
        device_options.var_options == 'AS_CONSTANTS'):
      # Do not specify devices for variables if we are marking them as
      # constants.
      device = ''
      tpu_const_scope = ConstGuaranteeScope()

    with cluster, tf.device(device), tpu_const_scope:

      bfloat16_override = ShouldForceBfloat16ForWeightsAndActivations(
          device_options)

      if bfloat16_override:
        py_utils.UpdateDtype(model_cfg, tf.bfloat16)
        py_utils.UpdateFpropDtype(model_cfg, tf.bfloat16)

      # Hard-code TPU-related flags prior to instantiating model.
      old_enable_asserts = FLAGS.enable_asserts
      old_xla_device = FLAGS.xla_device
      if IsTpu(device_options):
        FLAGS.enable_asserts = False
        FLAGS.xla_device = 'tpu'

      # Ensure the global_step variable is created.
      global_step_var = py_utils.GetOrCreateGlobalStepVar()
      global_step = tf.identity(global_step_var, name='global_step_tensor')
      with py_utils.GlobalStepContext(global_step):
        try:
          mdl = model_cfg.Instantiate()
          # With EMA, restore the averaged shadow variables instead of the
          # raw global variables.
          variables_to_restore = (_MakeVariableDictionary(
              tf.global_variables()) if not mdl.ema else
                                  mdl.ema.variables_to_restore(
                                      mdl.variables_for_ema))

          if bfloat16_override:
            saver_var_spec = (
                bfloat16_variables.
                get_saver_spec_for_variables_with_bf16_overrides(
                    variables_to_restore))
          else:
            saver_var_spec = variables_to_restore

          # The op names created here ('save/...', 'init_all_variables',
          # 'tpu_init_op') are referenced by name in the pruning step below,
          # so their construction must not be reordered or renamed.
          saver = tf.train.Saver(saver_var_spec)
          tf.variables_initializer(tf.global_variables(),
                                   name='init_all_variables')
          if IsTpu(device_options) and device_options.gen_init_op:
            tf.group(tf.tpu.initialize_system(), name='tpu_init_op')

          model_task = mdl.GetTask(model_task_name)

          inference_graph_proto = inference_graph_pb2.InferenceGraph()
          subgraphs_proto = model_task.Inference()
          if isinstance(subgraphs_proto, dict):
            subgraphs_proto = ConvertSubgraphDictToProto(subgraphs_proto)
          for name, subgraph in subgraphs_proto.subgraphs.items():
            if not subgraph_filter or name in subgraph_filter:
              inference_graph_proto.subgraphs[name].CopyFrom(subgraph)

          # Add a table init op and global variable init op to the graph.
          # Tables can be declared anywhere in the graph, so this op has to be
          # added last.
          tf.tables_initializer(name='init_all_tables')
        finally:
          # Reset TPU-related flags after model instantiation.
          FLAGS.enable_asserts = old_enable_asserts
          FLAGS.xla_device = old_xla_device

  tf.logging.info('Graph contains ops: %r',
                  [op.name for op in graph.get_operations()])

  inference_graph_proto.saver_def.CopyFrom(saver.as_saver_def())

  # Freezing.
  if freeze_defaults or freeze_checkpoint:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto,
                                       preserve_colocation_nodes=False)
    # NOTE(review): this raises when _DeviceSupportsFreezing(...) is True,
    # which reads inverted relative to the helper's name. Confirm the
    # helper's semantics before changing it.
    if cls._DeviceSupportsFreezing(device_options):
      raise ValueError('freeze_checkpoint cannot be used with device ' +
                       device_options.device)
    if freeze_checkpoint:
      tf.logging.info('Freezing graph from checkpoint: %s', freeze_checkpoint)
      graph_def = _FreezeGraphFromCheckpoint(graph, saver, freeze_checkpoint,
                                             output_op_names)
    elif freeze_defaults:
      tf.logging.info('Default initializing graph and freezing.')
      graph_def = _FreezeDefaults(graph, output_op_names)
  else:
    output_op_names = GetOutputOpNames(graph, inference_graph_proto)

    # Prune the graph to just the parts we need.
    # To support restoring, we have to not prune out the restore node.
    output_op_names.append('init_all_tables')
    output_op_names.append('init_all_variables')
    output_op_names.append('save/control_dependency')
    output_op_names.append('save/restore_all')
    if IsTpu(device_options) and device_options.gen_init_op:
      output_op_names.append('tpu_init_op')
    graph_def = graph.as_graph_def()
    tf.logging.info('Pruning graph to output ops: %r', output_op_names)
    graph_def = tf.graph_util.extract_sub_graph(graph_def, output_op_names)

  if not device_options.retain_device_placement:
    # Clear the device so that the runtime can choose.
    tf.logging.info('Clearing device placement for: %s',
                    device_options.device)
    for node in graph_def.node:
      node.ClearField('device')
    # Function bodies in the graph library carry their own node devices.
    for function in graph_def.library.function:
      for node_def in function.node_def:
        node_def.ClearField('device')

  inference_graph_proto.graph_def.CopyFrom(graph_def)

  if export_path:
    with tf.io.gfile.GFile(export_path, 'w') as f:
      f.write(text_format.MessageToString(inference_graph_proto))
  return inference_graph_proto
def __init__(self, params):
  """Runs the parent constructor under the cluster's input device scope."""
  input_device = self.cluster.input_device
  with tf.device(input_device):
    super(_UseInputDevice, self).__init__(params)
def SplitInputBatch(self, num_splits):
  """Delegates to the parent's SplitInputBatch under the input device scope.

  Args:
    num_splits: Forwarded unchanged to the parent implementation.

  Returns:
    Whatever the parent `SplitInputBatch` returns.
  """
  input_device = self.cluster.input_device
  with tf.device(input_device):
    splits = super(_UseInputDevice, self).SplitInputBatch(num_splits)
  return splits
def FProp(self, theta, *args):
  """Run multiple cells in different devices in a pipelining manner.

  In eval mode, the before-layers and cells are simply chained. Otherwise:
  - The before-layers run on the first cell's device.
  - Their outputs are split along `p.batch_dim` into `p.num_micro_batches`
    microbatches, stacked on a new leading axis.
  - The microbatches are fed through `recurrent.StackedRecurrent`, with cell
    i on `cluster.WorkerDeviceInModelSplit(i)`.
  - The microbatch axis is folded back into `p.batch_dim` on the last
    device.

  Args:
    theta: A NestedMap object containing weights' values of this layer and
      its children layers.
    *args: Non-keyworded variable length argument list of input tensors.

  Returns:
    A list of output tensors
  """
  # TODO(huangyp): handle optional None inputs.
  p = self.params
  if self.do_eval:
    # No pipelining in eval: run every layer sequentially on the full batch.
    outputs = copy.copy(args)
    for (name, l) in self._before_layers + self._cells:
      outputs = _ToTuple(outputs)
      outputs = l.FProp(theta[name], *outputs)
    return outputs

  num_cells = len(p.cell_tpl)
  cluster = self.cluster

  # Compute shapes of input and output tensors.
  input_shapes = self._get_input_shapes(*args)
  state_dtype = self._get_state_dtype(*args)
  # state_shapes[i] is the input shape spec of cell i, and state_shapes[-1]
  # is the final output spec (indexed as i / i + 1 below).
  state_shapes = self._CalculateOutputShapes(input_shapes)
  tf.logging.info('state_shapes={}'.format(state_shapes))

  def GetCellFn(i):
    """Get the ith feature extraction layer."""

    def CellFn(theta, state0, inputs):
      """A cell fn is executed inside of StackedRecurrent."""
      del state0

      def _FPropInputSetShape(name, t_shape):
        # Restores the static shape of a per-microbatch input slot.
        if t_shape is None:
          return None
        inputs[name].set_shape(t_shape.ToTensorShape().as_list())
        return inputs[name]

      if p.nested_map_fprop:
        # pylint: disable=protected-access
        fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
        # pylint: enable=protected-access
      else:
        # Positional inputs are stored in the state under keys 's0', 's1', ...
        fprop_inputs = []
        for input_idx, input_shape in enumerate(state_shapes[i]):
          name = 's{}'.format(input_idx)
          fprop_inputs.append(_FPropInputSetShape(name, input_shape))

      with py_utils.RemoveAssertContext(remove=True):
        with CellFnFPropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
          # The per-microbatch step value travels alongside the activations
          # and is installed as the global step for this cell's FProp.
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          fprop_inputs = _ToTuple(fprop_inputs)
          outputs = cell.FProp(theta, *fprop_inputs)

      if p.nested_map_fprop:
        assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
        state1 = outputs.Filter(lambda x: x is not None)
      else:
        state1 = py_utils.NestedMap()
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
      state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
      return state1, py_utils.NestedMap()

    return CellFn

  cell_fns = []
  accumulator_layers = []
  thetas = []
  init_states = []
  devices = []
  for cell_idx in range(num_cells):
    cell_name, cell = self._cells[cell_idx]
    accumulator_layers.append(cell)
    cell_fns.append(GetCellFn(cell_idx))
    thetas.append(theta[cell_name])

    def _TfZeros(t_shape):
      if t_shape is None:
        return None
      return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

    # Zero-valued initial state shaped like this cell's output.
    if p.nested_map_fprop:
      init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
      init_state = init_state.Filter(lambda x: x is not None)
    else:
      init_state = py_utils.NestedMap()
      for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
        state = _TfZeros(state)
        if state is not None:
          name = 's{}'.format(output_idx)
          init_state[name] = state
    init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
    init_states.append(init_state)

    devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

  # No custom gradients; outputs and output grads pass through unchanged.
  cell_grads = [None] * num_cells
  cell_outs = [lambda x: x] * num_cells
  cell_out_grads = [lambda x: x] * num_cells

  with tf.device(devices[0]):
    previous = _ToTuple(args)
    for (name, l) in self._before_layers:
      previous = l.FProp(theta[name], *previous)
      previous = _ToTuple(previous)

    def _StackAndSplit(x):
      # Split tensors into microbatches.
      # Result has a new leading axis of size p.num_micro_batches.
      if x is None:
        return None
      return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

    if p.nested_map_fprop:
      inputs = py_utils.Transform(_StackAndSplit, previous[0])
      inputs = inputs.Filter(lambda x: x is not None)
    else:
      inputs = py_utils.NestedMap()
      for output_idx, output_tensor in enumerate(previous):
        output_tensor = _StackAndSplit(output_tensor)
        if output_tensor is not None:
          name = 's{}'.format(output_idx)
          inputs[name] = output_tensor
    # One distinct value per microbatch: gs * num_micro_batches + t.
    gs_tensor = py_utils.GetGlobalStep()
    inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
        tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
        for t in range(p.num_micro_batches)
    ])
  tf.logging.info('pipeline input = {}'.format(inputs))
  output_state, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs,
      accumulator_layers=accumulator_layers,
      unused_acc_state=True)

  with tf.device(devices[-1]):

    def _ReshapeRetVal(name, t_shape):
      """Restore shape for tensors in microbatches."""
      if t_shape is None:
        return None
      output_tensor = output_state[name]
      # Move the leading microbatch axis next to batch_dim so the reshape
      # below merges the two in the right order.
      if p.batch_dim != 0:
        perm = list(range(1, p.batch_dim + 1)) + [0]
        perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
        output_tensor = tf.transpose(output_tensor, perm=perm)
      output_shape = t_shape.ToTensorShape().as_list()
      output_shape[p.batch_dim] *= p.num_micro_batches
      output_tensor = tf.reshape(output_tensor, output_shape)
      return output_tensor

    # Construct the final return values from output_state.
    if p.nested_map_fprop:
      # pylint: disable=protected-access
      output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
      # pylint: enable=protected-access
    else:
      output_tensors = []
      for output_idx, state_shape in enumerate(state_shapes[-1]):
        output_name = 's{}'.format(output_idx)
        output_tensor = _ReshapeRetVal(output_name, state_shape)
        output_tensors.append(output_tensor)
      # A single output is returned bare; several are returned as a tuple.
      if len(output_tensors) == 1:
        output_tensors = output_tensors[0]
      else:
        output_tensors = tuple(output_tensors)
    tf.logging.info('pipeline output = {}'.format(output_tensors))
    return output_tensors
def _DecoderDevice(self):
  """Returns the device to run the decoder computation.

  This is a `tf.device` context built from an empty device string. How that
  interacts with enclosing device scopes is TF's behavior; presumably it
  adds no explicit placement of its own (confirm).
  """
  no_explicit_device = ''
  return tf.device(no_explicit_device)
def __init__(self, train_cfg, ps_params_dict, model_task_name, logdir,
             tf_master, **kwargs):
  """Construct an ExecutorTpu BaseRunner.

  Steps performed:
  - Validates the single-executor cluster layout.
  - Sets up task scheduling for single- or multi-task models.
  - Instantiates one ProgramSchedule per entry of `ps_params_dict`.
  - Initializes the TPU system and records the device assignment via
    `py_utils.SetTpuDeviceAssignment`.
  - Builds every program's TPU subgraph, checkpointers and initializers in
    `self._graph`.

  Args:
    train_cfg: SingleTaskModelParams or MultiTaskModelParams
    ps_params_dict: A dict of top-level task name -> ProgramSchedule params,
      if train_cfg is a SingleTaskModelParams, we expect only one entry.
    model_task_name: An override for multi-task models, currently unused.
    logdir: String path to the log directory to output to.
    tf_master: String path to the master job, e.g. 'local'.
    **kwargs: keyword args to pass through to BaseRunner.
  """
  super(ExecutorTpu, self).__init__(train_cfg, model_task_name, logdir,
                                    tf_master, **kwargs)

  self._cluster_def = self._cluster.worker_cluster_def

  # There is a single Executor task
  assert self._cluster.num_replicas == 1
  data_parallelism = self._cluster.num_splits_per_client

  assert data_parallelism
  num_devices_per_split = self._cluster.num_devices_per_split
  tf.logging.info('data_parallelism: %d, num_devices_per_split: %d',
                  data_parallelism, num_devices_per_split)

  self.task_scheduler = None
  self._checkpoint_dir = os.path.join(logdir, 'train')

  self._variable_renaming_rules = []

  self._ml_perf = None

  # If this is a multi-task model, grab the params for the TaskScheduler.
  if issubclass(train_cfg.cls, base_model.SingleTaskModel):
    tf.logging.info('single_task_model')
    assert len(ps_params_dict) == 1
    self._model_task_name = list(ps_params_dict.keys())[0]
    self._single_task_mode = True
  elif issubclass(train_cfg.cls, base_model.MultiTaskModel):
    tf.logging.info('multi_task_model')

    if issubclass(train_cfg.cls, multitask_model.RegExSharedVariableModel):
      self._variable_renaming_rules = train_cfg.variable_renaming_rules

    # Without an explicit schedule, fall back to a ConstantScheduler over
    # the configured task probabilities.
    if train_cfg.task_schedule is None:
      task_schedule_params = task_scheduler.ConstantScheduler.Params()
      task_schedule_params.task_probs = sorted(
          list(train_cfg.task_probs.IterParams()))
    else:
      task_schedule_params = train_cfg.task_schedule
    self.task_scheduler = task_schedule_params.Instantiate()
    self._single_task_mode = False
  else:
    # NOTE(review): if tf.logging.fatal only logs without aborting,
    # self._single_task_mode is left unset on this path -- confirm.
    tf.logging.fatal(
        'Model %s is not a sub-class of SingleTaskModel or MultiTaskModel',
        train_cfg.cls)

  tf.logging.info('train_cfg.cls: %s', train_cfg.cls)

  self._WriteToLog(train_cfg.ToText(), self._checkpoint_dir, 'params.txt')
  self._program_schedule_dict = {}
  self._programs = []

  # Instantiate one ProgramSchedule per task and collect all their programs.
  # The last schedule with an ml_perf benchmark name wins self._ml_perf.
  for task_string, program_schedule_params in ps_params_dict.items():
    program_schedule_params.logdir = logdir
    program_schedule_params.num_splits_per_client = data_parallelism
    program_schedule_params.task_name = task_string
    ps = program_schedule_params.Instantiate()
    self._program_schedule_dict[task_string] = ps
    tf.logging.info('program_schedule_params: %s',
                    program_schedule_params.ToText())
    self._programs += ps.Programs()
    if program_schedule_params.ml_perf.benchmark_name is not None:
      self._ml_perf = program_schedule_params.ml_perf

  tf.logging.info('num_programs: %d', len(self._programs))

  if self._ml_perf is not None:
    self._ml_perf_log = True
    mlp_log.mlperf_print(key='benchmark', value=self._ml_perf.benchmark_name)
  else:
    self._ml_perf_log = False

  # BaseRunner legacy
  self.enqueue_ops = None

  @py_utils.RetryOnTransientTfError()
  def _WaitTillInit():
    """Wait until the model is ready."""
    # Initializes the TPU system, derives a device assignment with
    # data_parallelism replicas of num_devices_per_split cores each, and
    # publishes it via py_utils.SetTpuDeviceAssignment. Transient TF errors
    # are logged and re-raised so the retry decorator can retry.
    try:
      with self._graph.as_default(), self._GetSession(
          cluster_def=self._cluster_def,
          disable_meta_optimizer=FLAGS.disable_meta_optimizer_in_executor
      ) as sess:
        topology = sess.run(
            tf.tpu.initialize_system(embedding_config=None, job=None))
        device_assignment = device_assignment_lib.device_assignment(
            topology,
            computation_shape=py_utils.ComputationShape(num_devices_per_split),
            num_replicas=data_parallelism)
        py_utils.SetTpuDeviceAssignment(device_assignment)
        tf.logging.info('device_assignment.core_assignment: %s',
                        str(device_assignment.core_assignment))
        tf.logging.info('device_assignment.topology.device_coordinates: %s',
                        str(device_assignment.topology.device_coordinates))
    except py_utils.transient_tf_errors as e:
      tf.logging.info('TPU initialization failed: %s', e)
      raise

  if self._ml_perf_log:
    mlp_log.mlperf_print(key='cache_clear', value=True)
    mlp_log.mlperf_print(key='init_start', value=None)
  _WaitTillInit()

  with self._graph.as_default(), tf.container(self._container_id):
    tf.logging.info('self._cluster.job_spec.name: %s',
                    self._cluster.job_spec.name)
    with self._cluster, tf.device(
        self._cluster.job_spec.name if not FLAGS.cluster_placer_in_executor
        else self._cluster.GetPlacer()):
      # Variables created while building the TPU subgraphs are subject to
      # the multi-task renaming rules (empty list otherwise).
      with py_utils.VariableRenameScope(self._variable_renaming_rules):
        _ = py_utils.GetOrCreateGlobalStepVar()
        for program in self._programs:
          program.BuildTpuSubgraph()
      for program in self._programs:
        program.SetStatusMessageFn(self._SetStatusMessage)
        program.CreateCheckpointer()
      self._initialize_tables = tf.tables_initializer()
      self._initialize_local_vars = tf.local_variables_initializer()

      self.save_only_checkpointer = checkpointer.Checkpointer(
          self._checkpoint_dir,
          model=None,
          train_params=train_cfg.train,
          save_only=True)