def get_zero_batch(batch_size=None,
                   max_len=None,
                   key_size=2,
                   return_tgt_mask=False):
  """Returns a zero batch.

  Args:
    batch_size: batch size.
    max_len: max length.
    key_size: key size.
    return_tgt_mask: whether to return tgt_mask.

  Returns:
    A tuple of tensors:

    - key: int32 tensor [batch_size, key_size]
    - tgt_id: int32 tensor [batch_size, max_len]
    - tgt_segment_id: float32 tensor [batch_size, max_len]
    - tgt_segment_pos: int32 tensor [batch_size, max_len]
    - tgt_labels: int32 tensor [batch_size, max_len]
    - tgt_sample_temperature: float32 tensor [batch_size]
    - tgt_mask: optional float32 tensor [batch_size, max_len, max_len]
  """
  batch = preload_zero(
      n=1, batch_size=batch_size, max_len=max_len, key_size=key_size)
  batch = py_utils.Transform(lambda x: np.squeeze(x, 0), batch)
  if return_tgt_mask:
    tgt_mask = np.zeros([batch_size, max_len, max_len], np.float32)
    batch = (*batch, tgt_mask)
  return batch
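# Illustrative sketch (not part of the library): get_zero_batch relies on
# every array produced by preload_zero carrying a leading n=1 axis, which
# py_utils.Transform then strips by mapping np.squeeze over the nested batch
# structure. The same round trip with plain numpy; the demo name is
# hypothetical.
def demo_squeeze_batch():
  import numpy as np
  batch = (np.zeros([1, 8, 2], np.int32),    # key: [n=1, batch, key_size]
           np.zeros([1, 8, 16], np.int32))   # tgt_id: [n=1, batch, max_len]
  squeezed = tuple(np.squeeze(x, 0) for x in batch)
  assert squeezed[0].shape == (8, 2)
  assert squeezed[1].shape == (8, 16)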
def testGraphLayer(self):
  g = tf.Graph()
  with g.as_default(), self.SetEval(True):
    tf.random.set_seed(24332)

    def _FnMeta(*shapes):
      return py_utils.NestedMap(flops=1, out_shapes=shapes)

    p = layers.GraphLayer.Params().Set(
        name='graph',
        input_endpoints=['x'],
        output_endpoints=['y'],
        sub=[
            ('x.a->y.c', layers.FnLayer.Params().Set(
                fn=lambda x: 2 * x, fn_meta=_FnMeta)),
            ('x.b->y.d', layers.FnLayer.Params().Set(
                name='bar', fn=lambda x: x + 2, fn_meta=_FnMeta)),
            ('y.c,y.d->y.e,y.f', layers.FnLayer.Params().Set(
                name='baz', fn=lambda x, y: (x + y, x - y), fn_meta=_FnMeta)),
        ])
    l = p.Instantiate()
    x = py_utils.NestedMap(a=tf.constant(1.0), b=tf.constant(2.0))
    y = l.FProp(l.theta, x)
    y_shape = l.FPropMeta(
        p, py_utils.Transform(lambda t: tshape.Shape(t.shape),
                              x)).out_shapes[0]
    self.assertDictEqual(
        py_utils.Transform(lambda t: t.shape.as_list(), y),
        py_utils.Transform(lambda t: t.ToTensorShape().as_list(), y_shape))
  with self.session(graph=g) as sess:
    sess.run(tf.global_variables_initializer())
    y_val = sess.run(y)
    print(y_val)
    self.assertEqual(py_utils.NestedMap(c=2.0, d=4.0, e=6.0, f=-2.0), y_val)
def testEmptySequentialLayerFPropMeta(self):
  g = tf.Graph()
  with g.as_default():
    p = layers.SequentialLayer.Params().Set(name='seq')
    l = p.Instantiate()
    x = py_utils.NestedMap(val=tf.random.normal(shape=[2, 32]))
    y = l.FPropDefaultTheta(x)
    self.assertIsInstance(y.val, tf.Tensor)
    y_shape = l.FPropMeta(
        p, py_utils.Transform(lambda t: tshape.Shape(t.shape),
                              x)).out_shapes[0]
    self.assertEqual(y.val.shape.as_list(),
                     y_shape.val.ToTensorShape().as_list())
def _InternalGetTheta(self):
  ret = py_utils.Transform(lambda x: x.theta, self.children)
  private_theta = self._private_theta

  # When ExecutorTpu specifies the EMA (e.g. when running an eval/decode
  # program with EMA enabled), use the EMA version of the variables if
  # applicable.
  if self.cluster.is_executor_tpu and self.do_eval and self.ema:
    vars_loaded_as_ema = self.params.is_inference or (
        self.do_eval and not py_utils.use_tpu())
    assert not vars_loaded_as_ema, (
        'Not able to use EMA variables since the layer variables are '
        'potentially already loaded as EMA variables.')

    def MaybeUseEmaVar(x):
      if not isinstance(x, tf.Variable):
        raise ValueError('EMA is used but self._private_theta contains '
                         f'non-variables: {x}.')
      ema_x = self.ema.average(x)
      return ema_x if ema_x is not None else x

    private_theta = py_utils.Transform(MaybeUseEmaVar, private_theta)

  if (self._params.fprop_dtype is not None and
      self._params.fprop_dtype != self._params.dtype):

    def MaybeCastToFPropDtype(x):
      # Need to check `.base_dtype` as x.dtype may be tf.float32_ref.
      if x is not None and x.dtype.base_dtype == self._params.dtype:
        return tf.cast(x, self._params.fprop_dtype)
      else:
        return x

    private_theta = py_utils.Transform(MaybeCastToFPropDtype, private_theta)

  ret.update(private_theta)
  return ret
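# Minimal standalone sketch of the EMA substitution performed by
# MaybeUseEmaVar, using the public tf.train.ExponentialMovingAverage API
# (TF1-style graph mode; the demo name is hypothetical).
def demo_ema_average():
  import tensorflow.compat.v1 as tf
  with tf.Graph().as_default():
    v = tf.get_variable('w', initializer=1.0)
    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    ema.apply([v])           # Creates the shadow (EMA) variable for v.
    shadow = ema.average(v)  # What MaybeUseEmaVar returns in place of v.
    assert shadow is not None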
def _TransformVarsInternal(self, fn):
  """Internal: replaces each variable v in self._private_vars with fn(v).

  Also recursively invokes _TransformVarsInternal() on self.children.

  Args:
    fn: A function that takes a variable and returns a variable or a wrapper
      of the variable.
  """
  self._private_vars_transform_restore_stack.append(self._private_vars)
  self._private_vars = {key: fn(x) for key, x in self._private_vars.items()}
  py_utils.Transform(
      lambda c: c._TransformVarsInternal(fn),  # pylint: disable=protected-access
      self.children)
def CreateChildren(
    self, name: str, params: Union[List[BaseLayerParamsT],
                                   Mapping[str, BaseLayerParamsT]]
) -> None:
  """Creates a list or dict of sub layers.

  The created sub layers can be accessed by `name`. E.g.::

      self.CreateChildren('foo', ...)
      self.foo[10].FProp...

  or::

      self.children['foo'][10].FProp...
      self.children.foo[10].FProp...

  Args:
    name: The name for the sub layers, which is used as the key into
      vars/theta.
    params: a list or dict of `Hyperparams` objects to create.
  """
  if hasattr(self, '_disable_create_child') and self._disable_create_child:
    raise ValueError('Attempting to call CreateChildren outside of __init__.')
  self._CheckName(name)

  uid = itertools.count()

  def Instantiate(p):
    p = self.CopyBaseParams(self.params, p.Copy())
    if not p.name:
      p.name = '%s_%d' % (name, next(uid))
    return p.Instantiate()

  with self._CreateChildContext(name):
    self._private_children[name] = py_utils.Transform(Instantiate, params)
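# Illustrative usage of CreateChildren (a sketch; MyStack and MyBlock are
# hypothetical BaseLayer subclasses, not part of this module):
#
#   class MyStack(base_layer.BaseLayer):
#
#     def __init__(self, params):
#       super().__init__(params)
#       # Ten identically configured children named 'block_0'..'block_9'.
#       self.CreateChildren(
#           'block', [MyBlock.Params().Set(dim=128) for _ in range(10)])
#
#   # After instantiation: self.block[3].FProp(...), self.theta.block[3], ...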
def _CalculateOutputShapes(self, input_shapes):
  """Calculates the output shapes of intermediate layers.

  Given the FPropMeta function in each FeatureExtractionLayer, calculates
  the shapes of outputs of that layer. This is used to recover the shape
  information in StackedRecurrent.

  Args:
    input_shapes: NestedMap or tuple of input TensorShapes.

  Returns:
    A list of K + 1 NestedMaps or lists of tShape, where K is the number of
    partitions.
  """
  p = self.params
  shapes = []

  # Converts TensorShape to tshape.Shape.
  def _ToTShape(x):
    if x is None:
      return None
    return tshape.Shape(x.as_list())

  shapes = py_utils.Transform(_ToTShape, input_shapes)
  shapes = _ToTuple(shapes)

  state_shapes = []
  for (_, cell) in self._before_layers:
    shapes = cell.FPropMeta(cell.params, *shapes).out_shapes

  state_shapes.append(shapes[0] if p.nested_map_fprop else shapes)

  for (_, cell) in self._cells:
    shapes = cell.FPropMeta(cell.params, *shapes).out_shapes
    state_shapes.append(shapes[0] if p.nested_map_fprop else shapes)

  return state_shapes
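# Toy version of the shape threading above (hypothetical names): each cell
# maps input shapes to output shapes, and the shape at every partition
# boundary is recorded, so K cells yield K + 1 entries.
def _demo_output_shapes(input_shape, cell_shape_fns):
  shapes = [input_shape]
  for fn in cell_shape_fns:
    shapes.append(fn(shapes[-1]))
  return shapes  # len == len(cell_shape_fns) + 1

# _demo_output_shapes((8, 16), [lambda s: (s[0], 2 * s[1]),
#                               lambda s: (s[0], s[1] // 2)])
# -> [(8, 16), (8, 32), (8, 16)]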
def fetch_shapes(self):
  # Conversion from dict to NestedMap required.
  return py_utils.Transform(
      lambda x: self._graph.get_tensor_by_name(x).shape.as_list(),
      py_utils.NestedMap(self._fetches))
def _VarNamesDebugString(vars_):
  return py_utils.Transform(lambda x: x.name, vars_).DebugString()
def _DecodeOnce(self, sess=None, path=''):
  """Decode a single checkpoint."""
  with self._cluster:
    # Attempt to restore the checkpoint.
    self._checkpointer.RestoreFromPath(checkpoint_path=path)

    global_step = self._model.global_step.numpy()
    if global_step < self._task.params.eval.start_decoder_after:
      return

    if self._task.input.params.resettable:
      tf.logging.info('Resetting input_generator.')
      self._task.input_generator.Reset()

    dec_metrics = self._task.CreateDecoderMetrics()
    if not dec_metrics:
      tf.logging.info('Empty decoder metrics')
      return
    buffered_decode_out = []
    num_samples_metric = dec_metrics['num_samples_in_batch']

    samples_per_summary = self._task.params.eval.decoder_samples_per_summary
    if samples_per_summary is None:
      samples_per_summary = self._task.params.eval.samples_per_summary
    if samples_per_summary == 0:
      assert self._task.input.params.resettable

    start_time = time.time()
    while samples_per_summary == 0 or (num_samples_metric.total_value <
                                       samples_per_summary):
      try:
        tf.logging.info('Fetching dec_output.')
        fetch_start = time.time()

        # Decoder calls FProp multiple times for each checkpoint. Multiple
        # summaries at the same step is often confusing. Instead, models
        # should generate aggregate summaries using PostProcessDecodeOut.
        # Other types of summaries (images, audio etc.) will be generated
        # for the first batch only.
        is_first_loop = num_samples_metric.total_value == 0
        decode_fn = (
            self._decode_fn_with_summary
            if is_first_loop else self._decode_fn)
        input_batch, dec_output = decode_fn()

        for key in self._task.input_generator.GetCpuPassthroughKeys():
          if key in input_batch:
            if key in dec_output:
              tf.logging.warning(f'Key {key} already present in decode '
                                 'output. Not adding from input batch.')
            else:
              dec_output[key] = input_batch[key]

        dec_output = py_utils.Transform(lambda x: x.numpy(), dec_output)

        post_process_start = time.time()
        tf.logging.info('Done fetching (%f seconds)' %
                        (post_process_start - fetch_start))
        decode_out = self._task.PostProcessDecodeOut(dec_output, dec_metrics)

        if decode_out:
          if isinstance(decode_out, dict):
            decode_out = decode_out.items()

          if is_first_loop:
            # Add summaries only for the first batch of data.
            with self._summary_writer.as_default():
              for key, value in decode_out:
                if isinstance(value, tf.Summary):
                  tf.logging.info(f'Adding summary {key} with tags '
                                  f'{[x.tag for x in value.value]}.')
                  tf.compat.v2.summary.experimental.write_raw_pb(
                      tf.constant(value.SerializeToString()), global_step)

          buffered_decode_out.extend(
              kv for kv in decode_out if not isinstance(kv[1], tf.Summary))

        tf.logging.info(
            'Total examples done: %d/%d '
            '(%f seconds decode postprocess)',
            num_samples_metric.total_value, samples_per_summary,
            time.time() - post_process_start)
      except tf.errors.OutOfRangeError:
        if not self._task.input.params.resettable:
          raise
        break

    tf.logging.info('Done decoding ckpt: %s', path)

    elapsed_secs = time.time() - start_time
    example_rate = num_samples_metric.total_value / elapsed_secs
    msg = 'step:%6d, elapsed_secs: %0.2f, examples/sec: %0.2f' % (
        global_step, elapsed_secs, example_rate)
    with self._summary_writer.as_default():
      tf.compat.v2.summary.scalar(
          'decode_secs', elapsed_secs, step=global_step)
      tf.compat.v2.summary.scalar(
          'examples/sec', example_rate, step=global_step)
      tf.compat.v2.summary.scalar(
          'total_samples', num_samples_metric.total_value, step=global_step)
      for key, metric in sorted(dec_metrics.items()):
        msg += ' %s:%.8g' % (key, metric.value)
        tf.compat.v2.summary.scalar(key, metric.value, step=global_step)
    self._summary_writer.flush()
    self._SetStatusMessage(msg)
    self._ExportMetrics(
        # Metrics expects a Python int, but global_step is numpy.int64.
        decode_checkpoint=int(global_step),
        dec_metrics=dec_metrics,
        example_rate=example_rate)

    decode_out_path = self.GetDecodeOutPath(self._decoder_dir, global_step)
    decode_finalize_args = base_model.DecodeFinalizeArgs(
        decode_out_path=decode_out_path, decode_out=buffered_decode_out)
    self._task.DecodeFinalize(decode_finalize_args)
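# The raw-protobuf summary write above uses a public TF2 API. A standalone
# sketch (illustrative; writes to a temporary directory):
def demo_write_raw_pb():
  import tempfile
  import tensorflow as tf
  s = tf.compat.v1.Summary()
  s.value.add(tag='loss', simple_value=0.5)
  writer = tf.summary.create_file_writer(tempfile.mkdtemp())
  with writer.as_default():
    tf.summary.experimental.write_raw_pb(
        tf.constant(s.SerializeToString()), step=0)
  writer.flush()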
def subgraph_feed_shapes(self, subgraph_name):
  # Conversion from dict to NestedMap required.
  return py_utils.Transform(
      lambda x: self._graph.get_tensor_by_name(x).shape.as_list(),
      py_utils.NestedMap(self._get_subgraph_feeds(subgraph_name)))
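# Standalone sketch of the tensor-name -> shape lookup shared by
# fetch_shapes and subgraph_feed_shapes above (TF1 graph mode; the demo
# name is hypothetical). Note the ':0' output suffix in tensor names.
def demo_shape_by_name():
  import tensorflow.compat.v1 as tf
  g = tf.Graph()
  with g.as_default():
    tf.placeholder(tf.float32, shape=[2, 32], name='x')
  assert g.get_tensor_by_name('x:0').shape.as_list() == [2, 32]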
def AddIdentityToTheta(layer):
  # pylint: disable=protected-access
  layer._private_theta = py_utils.Transform(tf.identity, layer._private_theta)
  # pylint: enable=protected-access
  layer.children.Transform(AddIdentityToTheta)
def accumulators(self):
  """Returns `.NestedMap` of `Accumulator` instances for this and children."""
  ret = py_utils.Transform(lambda x: x.accumulators, self.children)
  for k, acc in self._private_accumulators.items():
    ret[k] = acc
  return ret
def _UndoTransformVarsInternal(self):
  """Internal. Undoes _TransformVarsInternal()."""
  self._private_vars = self._private_vars_transform_restore_stack.pop()
  py_utils.Transform(
      lambda c: c._UndoTransformVarsInternal(),  # pylint: disable=protected-access
      self.children)
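# _TransformVarsInternal/_UndoTransformVarsInternal form a push/pop pair:
# repeated transforms nest, and undos restore in LIFO order. A self-contained
# sketch of the same pattern (hypothetical class, not the Lingvo BaseLayer):
class _DemoNode:

  def __init__(self, children=()):
    self._private_vars = {'w': 1.0}
    self._restore_stack = []
    self.children = list(children)

  def transform_vars(self, fn):
    self._restore_stack.append(self._private_vars)
    self._private_vars = {k: fn(v) for k, v in self._private_vars.items()}
    for c in self.children:
      c.transform_vars(fn)

  def undo_transform_vars(self):
    self._private_vars = self._restore_stack.pop()
    for c in self.children:
      c.undo_transform_vars()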
def _EvalOnce(self, sess=None, path=''):
  """Eval a single checkpoint."""
  with self._cluster:
    # Attempt to restore the checkpoint.
    self._checkpointer.RestoreFromPath(checkpoint_path=path)

    # Save any additional information to disk before evaluation.
    if self._eval_type == 'train':
      self._task.Export(path)

    global_step = self._model.global_step.numpy()
    if global_step < self._task.params.eval.start_eval_after:
      return

    if self._task.input.params.resettable:
      tf.logging.info('Resetting input_generator.')
      self._task.input_generator.Reset()

    metrics_dict = None
    num_samples_metric = None
    samples_per_summary = self._task.params.eval.samples_per_summary
    if samples_per_summary == 0:
      assert self._task.input.params.resettable
    while (samples_per_summary == 0 or metrics_dict is None or
           num_samples_metric.total_value < samples_per_summary):
      try:
        # Evaler calls FProp multiple times for each checkpoint. Multiple
        # summaries at the same step is often confusing. Instead, models
        # should update eval_metrics and generate aggregate summaries. Other
        # types of summaries (images, audio etc.) will be generated for the
        # first batch only.
        eval_fn = (
            self._eval_fn_with_summary
            if metrics_dict is None else self._eval_fn)
        eval_metrics = eval_fn()

        if metrics_dict is None:
          metrics_dict = {
              name: metrics.AverageMetric() for name in eval_metrics
          }
          num_samples_metric = metrics_dict['num_samples_in_batch']

        eval_metrics = py_utils.Transform(lambda x: x.numpy(), eval_metrics)
        for name, (value, weight) in eval_metrics.items():
          metrics_dict[name].Update(value, weight)
        tf.logging.info('Total examples done: %d/%d',
                        num_samples_metric.total_value, samples_per_summary)
      except tf.errors.OutOfRangeError:
        if not self._task.input.params.resettable:
          raise
        break

    if metrics_dict is None:
      metrics_dict = {}

    # Replace average values with total values for certain metrics.
    if 'num_predictions' in metrics_dict:
      metrics_dict['num_predictions'].total_weight = 1.0
    if 'num_words' in metrics_dict:
      metrics_dict['num_words'].total_weight = 1.0

    msg = 'step:%6d' % global_step
    with self._summary_writer.as_default():
      tf.compat.v2.summary.scalar(
          'total_samples', num_samples_metric.total_value, step=global_step)
      for key, metric in sorted(metrics_dict.items()):
        msg += ' %s:%.8g' % (key, metric.value)
        tf.compat.v2.summary.scalar(key, metric.value, step=global_step)
    self._summary_writer.flush()
    self._SetStatusMessage(msg)
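# Minimal stand-in for the AverageMetric contract assumed above (weighted
# running average; not lingvo.core.metrics.AverageMetric). It also shows why
# setting total_weight = 1.0 makes .value report the accumulated total, which
# the 'num_predictions'/'num_words' override exploits.
class _DemoAverageMetric:

  def __init__(self):
    self.total_value = 0.0
    self.total_weight = 0.0

  def Update(self, value, weight=1.0):
    self.total_value += value * weight
    self.total_weight += weight

  @property
  def value(self):
    return self.total_value / self.total_weight if self.total_weight else 0.0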
def FProp(self, theta, *args):
  """Runs multiple cells on different devices in a pipelining manner.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer
      and its children layers.
    *args: Non-keyworded variable length argument list of input tensors.

  Returns:
    A list of output tensors.
  """
  # TODO(huangyp): handle optional None inputs.
  p = self.params
  if self.do_eval:
    outputs = copy.copy(args)
    for (name, l) in self._before_layers + self._cells:
      outputs = _ToTuple(outputs)
      outputs = l.FProp(theta[name], *outputs)
    return outputs

  num_cells = len(p.cell_tpl)
  cluster = self.cluster

  # Compute shapes of input and output tensors.
  input_shapes = self._get_input_shapes(*args)
  state_dtype = self._get_state_dtype(*args)
  state_shapes = self._CalculateOutputShapes(input_shapes)
  tf.logging.info('state_shapes={}'.format(state_shapes))

  def GetCellFn(i):
    """Gets the i-th feature extraction layer."""

    def CellFn(theta, state0, inputs):
      """A cell fn is executed inside of StackedRecurrent."""
      del state0

      def _FPropInputSetShape(name, t_shape):
        if t_shape is None:
          return None
        inputs[name].set_shape(t_shape.ToTensorShape().as_list())
        return inputs[name]

      if p.nested_map_fprop:
        # pylint: disable=protected-access
        fprop_inputs = state_shapes[i]._RecursiveMap(_FPropInputSetShape)
        # pylint: enable=protected-access
      else:
        fprop_inputs = []
        for input_idx, input_shape in enumerate(state_shapes[i]):
          name = 's{}'.format(input_idx)
          fprop_inputs.append(_FPropInputSetShape(name, input_shape))

      with py_utils.RemoveAssertContext(remove=True):
        with CellFnFPropOpReplacementWrapper():
          tf.logging.info('cell {} input {}'.format(i, fprop_inputs))
          mb_tensor = inputs[_MICRO_BATCH_STATE_NAME]
          SetOverWriteGlobalStep(mb_tensor)
          _, cell = self._cells[i]
          fprop_inputs = _ToTuple(fprop_inputs)
          outputs = cell.FProp(theta, *fprop_inputs)

      if p.nested_map_fprop:
        assert py_utils.IsCompatible(outputs, state_shapes[i + 1])
        state1 = outputs.Filter(lambda x: x is not None)
      else:
        state1 = py_utils.NestedMap()
        outputs = _ToTuple(outputs)
        assert len(outputs) == len(state_shapes[i + 1])
        for output_idx in range(len(outputs)):
          if outputs[output_idx] is not None:
            name = 's{}'.format(output_idx)
            state1[name] = outputs[output_idx]
      state1[_MICRO_BATCH_STATE_NAME] = mb_tensor
      return state1, py_utils.NestedMap()

    return CellFn

  cell_fns = []
  accumulator_layers = []
  thetas = []
  init_states = []
  devices = []
  for cell_idx in range(num_cells):
    cell_name, cell = self._cells[cell_idx]
    accumulator_layers.append(cell)
    cell_fns.append(GetCellFn(cell_idx))
    thetas.append(theta[cell_name])

    def _TfZeros(t_shape):
      if t_shape is None:
        return None
      return tf.zeros(t_shape.ToTensorShape().as_list(), dtype=state_dtype)

    if p.nested_map_fprop:
      init_state = py_utils.Transform(_TfZeros, state_shapes[cell_idx + 1])
      init_state = init_state.Filter(lambda x: x is not None)
    else:
      init_state = py_utils.NestedMap()
      for output_idx, state in enumerate(state_shapes[cell_idx + 1]):
        state = _TfZeros(state)
        if state is not None:
          name = 's{}'.format(output_idx)
          init_state[name] = state
    init_state[_MICRO_BATCH_STATE_NAME] = tf.cast(0, dtype=state_dtype)
    init_states.append(init_state)
    devices.append(cluster.WorkerDeviceInModelSplit(cell_idx))

  cell_grads = [None] * num_cells
  cell_outs = [lambda x: x] * num_cells
  cell_out_grads = [lambda x: x] * num_cells

  with tf.device(devices[0]):
    previous = _ToTuple(args)
    for (name, l) in self._before_layers:
      previous = l.FProp(theta[name], *previous)
      previous = _ToTuple(previous)

    def _StackAndSplit(x):
      # Split tensors into microbatches.
      if x is None:
        return None
      return tf.stack(tf.split(x, p.num_micro_batches, axis=p.batch_dim))

    if p.nested_map_fprop:
      inputs = py_utils.Transform(_StackAndSplit, previous[0])
      inputs = inputs.Filter(lambda x: x is not None)
    else:
      inputs = py_utils.NestedMap()
      for output_idx, output_tensor in enumerate(previous):
        output_tensor = _StackAndSplit(output_tensor)
        if output_tensor is not None:
          name = 's{}'.format(output_idx)
          inputs[name] = output_tensor
    gs_tensor = py_utils.GetGlobalStep()
    inputs[_MICRO_BATCH_STATE_NAME] = tf.stack([
        tf.cast(gs_tensor * p.num_micro_batches + t, dtype=state_dtype)
        for t in range(p.num_micro_batches)
    ])
  tf.logging.info('pipeline input = {}'.format(inputs))
  output_state, _ = recurrent.StackedRecurrent(
      devices=devices,
      cell_fns=cell_fns,
      cell_grads=cell_grads,
      cell_outs=cell_outs,
      cell_out_grads=cell_out_grads,
      thetas=thetas,
      init_states=init_states,
      inputs=inputs,
      accumulator_layers=accumulator_layers,
      unused_acc_state=True)

  with tf.device(devices[-1]):

    def _ReshapeRetVal(name, t_shape):
      """Restores the shape for tensors in microbatches."""
      if t_shape is None:
        return None
      output_tensor = output_state[name]
      if p.batch_dim != 0:
        perm = list(range(1, p.batch_dim + 1)) + [0]
        perm += list(range(p.batch_dim + 1, t_shape.rank + 1))
        output_tensor = tf.transpose(output_tensor, perm=perm)
      output_shape = t_shape.ToTensorShape().as_list()
      output_shape[p.batch_dim] *= p.num_micro_batches
      output_tensor = tf.reshape(output_tensor, output_shape)
      return output_tensor

    # Construct the final return values from output_state.
    if p.nested_map_fprop:
      # pylint: disable=protected-access
      output_tensors = state_shapes[-1]._RecursiveMap(_ReshapeRetVal)
      # pylint: enable=protected-access
    else:
      output_tensors = []
      for output_idx, state_shape in enumerate(state_shapes[-1]):
        output_name = 's{}'.format(output_idx)
        output_tensor = _ReshapeRetVal(output_name, state_shape)
        output_tensors.append(output_tensor)
      if len(output_tensors) == 1:
        output_tensors = output_tensors[0]
      else:
        output_tensors = tuple(output_tensors)
    tf.logging.info('pipeline output = {}'.format(output_tensors))
    return output_tensors
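# Standalone round-trip check of the microbatching above (batch_dim == 0
# case): _StackAndSplit turns [B, ...] into
# [num_micro_batches, B // num_micro_batches, ...] and the reshape in
# _ReshapeRetVal undoes it.
def demo_microbatch_roundtrip():
  import tensorflow as tf
  x = tf.reshape(tf.range(12, dtype=tf.float32), [6, 2])  # B = 6
  stacked = tf.stack(tf.split(x, 3, axis=0))               # [3, 2, 2]
  restored = tf.reshape(stacked, [6, 2])                   # back to [B, ...]
  tf.debugging.assert_near(x, restored)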
def vars(self):
  """Returns variables of this layer and its children in a `.NestedMap`."""
  ret = py_utils.Transform(lambda x: x.vars, self.children)
  for k in self._private_vars.keys():
    ret[k] = self._private_vars[k]
  return ret
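# vars and accumulators (above) share one recursion: map the property over
# self.children, then overlay this layer's own private entries so they win
# on name collisions. A plain-dict sketch (hypothetical helper):
def _demo_collect(node):
  ret = {name: _demo_collect(child)
         for name, child in node['children'].items()}
  ret.update(node['private'])
  return ret

# _demo_collect({'private': {'w': 1}, 'children': {
#     'sub': {'private': {'b': 2}, 'children': {}}}})
# -> {'sub': {'b': 2}, 'w': 1}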