def _BeamSearchDecode(self, input_batch):
  p = self.params
  with tf.name_scope('fprop'), tf.name_scope(p.name):
    encoder_outputs = self.enc.FPropDefaultTheta(input_batch.src)
    encoder_outputs = self.dec.AddExtraDecodingInfo(encoder_outputs,
                                                    input_batch.tgt)
    decoder_outs = self.dec.BeamSearchDecode(encoder_outputs)
    return self._ProcessBeamSearchDecodeOut(input_batch, encoder_outputs,
                                            decoder_outs)

def FProp(self, theta, *args):
  r"""Applies lambda(x, \*kwargs) for every non-None arg."""
  del theta
  p = self.params
  with tf.name_scope(p.name):
    ret = [None if x is None else p.fn(x, **p.kwargs) for x in args]
    return tuple(ret) if len(ret) > 1 else ret[0]

def FProp(self, theta, *args): """FProp through multiple devices in the split. Args: theta: A NestedMap object containing weights' values of this layer and its children layers. *args: A tuple of Tensors (one or more). Every tensor's first dimension is the same (the batch dimension). Returns: The sub layer's output. """ p = self.params with tf.name_scope(p.name): assert all(isinstance(x, tf.Tensor) for x in args) cluster = self.cluster num = cluster.num_devices_per_split if num == 1: return self.sub.FProp(theta.sub, *args) inps = py_utils.SplitRecursively(list(args), num, axis=0) outs = [] for i, xs in enumerate(inps): device = cluster.WorkerDeviceInModelSplit(i) tf.logging.info('%d on device %s', i, device) with tf.device(device): ys = self.sub.FProp(theta.sub, *xs) if isinstance(ys, tuple): outs += [list(ys)] else: outs += [ys] # ys is a single tensor ret = py_utils.ConcatRecursively(outs, axis=0) if isinstance(ret, list): return tuple(ret) else: return ret # ys is a single tensor
def FProp(self, theta, *args):
  p = self.params
  with tf.name_scope(p.name):
    args = _ToTuple(self.body.FProp(theta.body, *args))
    for fetch in p.fetches:
      args += (self.body.GetDescendant(fetch).activation,)
    return args

def Step(recurrent_theta, state0, inputs):
  """Computes one decoder step."""
  del inputs
  with tf.name_scope('single_sampler_step'):
    # Compute logits and states.
    bs_result, bs_state1 = pre_step_callback(
        recurrent_theta.theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),  # [batch, 1].
        state0.bs_state,
        num_hyps_per_beam=1)
    batch = tf.shape(bs_result.log_probs)[0]
    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = bs_result.log_probs
    # Sample ids from logits. [batch].
    state1.ids = tf.reshape(
        tf.random.stateless_categorical(
            state1.logits / p.temperature,
            num_samples=1,
            seed=tf.stack([recurrent_theta.random_seed, state0.timestep]),
            dtype=state0.ids.dtype,
            name='sample_next_id'), [batch])
    if 'is_last_chunk' in bs_result and p.target_eoc_id >= 0:
      state1.ids = tf.where(
          tf.math.logical_and(bs_result.is_last_chunk,
                              tf.equal(state1.ids, p.target_eoc_id)),
          tf.fill(tf.shape(state1.ids), p.target_eos_id), state1.ids)
    state1.bs_state = post_step_callback(recurrent_theta.theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, bs_state1)
  return state1, py_utils.NestedMap()

def fast_gather(values,
                ids,
                ids_size,
                max_value=None,
                axis=0,
                batch_major_state=True):
  """Fast implementation of gather on TPUs.

  Args:
    values: Values to gather from.
    ids: ids (rows to gather).
    ids_size: id space size.
    max_value: Optional hint on the maximum value, for int32 values, that
      allows speeding up the gather operation.
    axis: Axis to gather on. Defaults to 0 (rows).
    batch_major_state: Whether the values to gather from use batch major or
      not. Defaults to True.

  Returns:
    Gathered values.

  Raises:
    ValueError: If `values` is of type int64.
  """
  values = tf.convert_to_tensor(values)
  ids = tf.convert_to_tensor(ids)
  with tf.name_scope("fast_gather"):
    return _Gatherer(ids, ids_size)(
        values,
        max_value=max_value,
        axis=axis,
        batch_major_state=batch_major_state)

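# A minimal usage sketch for fast_gather, assuming it behaves like tf.gather
# on the given axis (illustration only; the tensor values below are
# hypothetical, not from the original source).
def _example_fast_gather():
  values = tf.constant([[0., 0.], [1., 1.], [2., 2.], [3., 3.]])
  ids = tf.constant([2, 0], dtype=tf.int32)
  # ids_size must cover the id space (here, the 4 rows of `values`).
  gathered = fast_gather(values, ids, ids_size=4)
  # Expected result: [[2., 2.], [0., 0.]], matching tf.gather(values, ids).
  return gathered
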
def PrepareExternalInputs(self, theta, external_inputs):
  """Prepares external inputs for each sub-step.

  The external_inputs parameter of this method is filtered through each
  sub-step's external_signature, and the selected tensors are then passed
  to that sub-step's PrepareExternalInputs method.

  Args:
    theta: Variables used by sub-steps.
    external_inputs: A NestedMap of [n_batch, ...] tensors.

  Returns:
    A NestedMap of prepared inputs, where the keys are the names of
    each sub-step.
  """
  graph_tensors = builder_layers.GraphTensors()
  graph_tensors.StoreTensor('external_inputs', external_inputs)
  prepared_inputs = py_utils.NestedMap()
  with tf.name_scope(self.params.name):
    for seq in self._seq:
      if seq.external_signature:
        template = py_utils.NestedMap(inputs=seq.external_signature.inputs)
        packed = template.Transform(graph_tensors.GetTensor)
        seq_external_inputs = packed.inputs[0]
        prepared_inputs[seq.name] = seq.step.PrepareExternalInputs(
            theta[seq.name], seq_external_inputs)
      else:
        prepared_inputs[seq.name] = py_utils.NestedMap()
  return prepared_inputs

def FProp(self, theta, *args):
  p = self.params
  # Stacks extra (non-variable) theta values p.repeat times so that
  # everything in theta.body has a leading repeat dimension.
  theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)

  def _ArgsToState(arg_list):
    """Returns a NestedMap from a list of FProp args."""
    state = py_utils.NestedMap()
    # Maintains a mapping from arg_idx to tensor. States cannot contain
    # None tensors.
    for idx in range(len(args)):
      if arg_list[idx] is not None:
        state['_s{}'.format(idx)] = arg_list[idx]
    return state

  def _StateToArgs(state):
    """Returns a list of FProp args from a NestedMap."""
    arg_list = []
    for idx in range(len(args)):
      attr = '_s{}'.format(idx)
      arg_list.append(state[attr] if attr in state else None)
      if arg_list[-1] is not None:
        arg_list[-1].set_shape(args[idx].shape)
    return arg_list

  def _CellFn(unused_theta, state0, theta_i):
    """Recurrent cell function wrapper of body.FProp."""
    # Retrieves fprop arguments from state and sets shapes.
    fprop_inputs = _StateToArgs(state0)
    # Sets shapes for theta_i as well.
    for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
      if src is not None:
        dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))
    # Runs the actual body.FProp.
    fprop_outputs = self.body.FProp(theta_i, *fprop_inputs)
    fprop_outputs = _ToTuple(fprop_outputs)
    assert len(fprop_outputs) == len(fprop_inputs)
    # Passes fprop outputs to the next iteration through state.
    state1 = _ArgsToState(fprop_outputs)
    return state1, py_utils.NestedMap()

  with tf.name_scope(p.name):
    # Add FProp arg list to state0.
    state0 = _ArgsToState(args)
    # Runs body.FProp k times using Recurrent, where k = p.repeat (the
    # leading dimension of theta_stack).
    _, state1 = recurrent.Recurrent(
        theta=py_utils.NestedMap(),
        state0=state0,
        inputs=theta_stack,  # Pass cell_fn theta through inputs.
        cell_fn=_CellFn)
    # Retrieves fprop outputs from state1 and sets shapes.
    output_tensors = _StateToArgs(state1)
    return output_tensors[0] if len(args) == 1 else tuple(output_tensors)

def FProp(self, theta, x):
  tf.logging.vlog(1, 'layer %s', self.params.name)
  with tf.name_scope(self.params.name):
    for (name, ch) in self._seq:
      th = theta[name]
      tf.logging.vlog(1, '  call %s %s %s', ch.params.name, ch, x)
      x = ch.FProp(th, x)
    return x

def FProp(self, theta, input_batch): """Encodes source as represented by `inputs` and `paddings`. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. input_batch: A `.NestedMap` with fields: - ids: The inputs tensor. It is expected to be of shape [batch, time]. - paddings: The paddings tensor. Expected shape [batch, time]. Returns: A NestedMap containing: - encoded: The encoded features, a tensor of shape [time, batch, depth] - padding: of shape [time, batch] - segment_id: [time, batch] if packed inputs are supported by the model (and all layers), or None otherwise. """ p = self.params src_segment_id = None with tf.name_scope(p.name): # Now the rnn layers. inputs = py_utils.with_dependencies([ py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]), py_utils.assert_shape_match(tf.shape(input_batch.ids), tf.shape(input_batch.paddings)) ], tf.transpose(input_batch.ids)) paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2) xs = self.emb.EmbLookup(theta.emb, inputs) xs = self.ApplyClipping(theta, xs) self._emb_out = xs ps = paddings # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell # with the same cc_schedule so that the RNN layer output is within # clipping range. xs = self.rnn[0].FProp(theta.rnn[0], xs, ps) xs = self.dropout.FProp(theta.dropout, xs) for i in range(1, p.num_lstm_layers): layer = self.rnn[i] ys, _ = layer.FProp(theta.rnn[i], xs, ps) ys = self.dropout.FProp(theta.dropout, ys) if hasattr(layer.params, 'cell'): layer_params = layer.params.cell else: layer_params = layer.params if layer_params.num_input_nodes == layer_params.num_output_nodes: xs += ys # Residual skip xs = self.ApplyClipping(theta, xs) else: # When cc_schedule is specified, make sure lstm_tpl is # QuantizedLSTMCell with the same cc_schedule so that the RNN layer # output is within clipping range. xs = ys return py_utils.NestedMap(encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
def FProp(self, theta, current_step):
  p = self.params
  assert p.total_steps > 0
  assert p.initial_value > p.final_value
  with tf.name_scope(p.name):
    decay_gap = p.initial_value - p.final_value
    return p.final_value + 0.5 * decay_gap * (1 + tf.cos(
        math.pi *
        tf.minimum(1.0, tf.cast(current_step, tf.float32) / p.total_steps)))

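# The schedule above is clamped cosine decay:
#   value(t) = final + 0.5 * (initial - final) * (1 + cos(pi * min(1, t / total))).
# A pure-Python sketch for reference (illustration only, not part of the
# original source; assumes `math` is imported, as in the schedule above):
def _cosine_decay_reference(step, initial_value, final_value, total_steps):
  frac = min(1.0, float(step) / total_steps)
  return final_value + 0.5 * (initial_value - final_value) * (
      1 + math.cos(math.pi * frac))

# E.g. halfway through, the value is the midpoint of initial and final:
# _cosine_decay_reference(50, 1.0, 0.1, 100) == 0.55 (up to float rounding).
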
def FProp(self, theta, inputs, paddings): """Apply convolution to inputs. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. inputs: The inputs tensor. It is expected to be of shape [batch, time, frequency, channel]. The time dimension corresponds to the height dimension as in images and the frequency dimension corresponds to the width dimension as in images. paddings: The paddings tensor, expected to be of shape [batch, time]. Returns: outputs, out_paddings pair. """ p = self.params with tf.name_scope(p.name): inputs = py_utils.with_dependencies([ py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]), py_utils.assert_shape_match( tf.shape(inputs), tf.concat([ tf.shape(paddings), [-1, symbolic.ToStatic(self.input_channels)] ], 0)) ], inputs) def _ApplyPadding(tensor_in, padding_in): padding_expanded = tf.expand_dims( tf.expand_dims(padding_in, -1), -1) return tensor_in * (1.0 - padding_expanded) # Zeroing out padded inputs. inputs = _ApplyPadding(inputs, paddings) # Apply conv on 'inputs'. out = self._ApplyConv(theta, inputs) if p.partial_conv: out = self._RescaleBoundary(out, paddings) # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1. # But there's likely no real problems. Trying to set it gives an error: # pooling with SAME padding is not implemented for dilation_rate > 1. # NOTE: we use window=p.filter_stride[0] to be compatible with legacy # implementation. Consider updating it to be the actual shape. conv_padding = ComputeConvOutputPadding(paddings, window=p.filter_stride[0], stride=p.filter_stride[0]) # Assuming padded nodes will be properly zero-ed out if necessary by # sub-sequent layers. # out = _ApplyPadding(out, conv_padding) out = py_utils.HasShape( out, symbolic.ToStatic(self.OutShape(tf.shape(inputs)))) return out, conv_padding
def FProp(self, theta, *args): r"""Applies a function (p.fn) on args. Args: theta: Unused. *args: A tuple of Tensors (one or more). Returns: fn(\*args). """ with tf.name_scope(self.params.name): return self.params.fn(*args)
def FProp(self, theta, input_batch):
  p = self.params
  with tf.name_scope(p.name):
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
    if p.packed_input:
      src_segment_id = tf.expand_dims(
          tf.transpose(input_batch.segment_ids), 2)
    else:
      src_segment_id = None

    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys = layer.FProp(theta.rnn[i], xs, ps, segment_id=src_segment_id)
      ys = self.dropout.FProp(theta.dropout, ys)
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    if p.lstm_cell_size * 2 != p.encoder_out_dim:
      # Project to the right depth.
      xs = self.final_proj.FProp(theta.final_proj, xs, ps)
      summary_utils.histogram('final_proj_out', xs)

    if src_segment_id is not None:
      src_segment_id = tf.squeeze(src_segment_id, [2])

    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)

def FProp(self, theta, inputs): """Adds bias to inputs. Args: theta: A NestedMap object containing weights' values of this layer and its children layers. inputs: The inputs tensor. Shaped [..., dims]. Returns: Inputs plus bias. """ with tf.name_scope(self.params.name): return inputs + theta.b
def _Traverse(layer):
  """Adds accumulators to layer and its descendant layers."""
  if isinstance(layer, (list, tuple)):
    for layer_i in layer:
      _Traverse(layer_i)
    return
  with tf.name_scope(layer.params.name):
    for cost_metric_name in COST_METRICS:
      dtype = COST_METRICS[cost_metric_name]
      layer.RegisterAccumulator(
          cost_metric_name,
          bn_layers.AddingAccumulator(shape=[], dtype=dtype))
    for _, child in sorted(layer.children.items()):
      _Traverse(child)

def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients."""
  for name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        var = tf.abs(var)
        grad = tf.abs(grad)
      histogram('var_hist/' + name, var)
      histogram('grad_hist/' + name, grad)

def FProp(self, theta, *args):
  p = self.params
  with tf.name_scope(p.name):
    # Computes sub layers in parallel.
    outputs = []
    for (name, ch) in self._seq:
      th = theta[name]
      out = ch.FProp(th, *args)
      if isinstance(out, (list, tuple)):
        outputs.append(tuple(out))
      else:
        outputs.append((out,))
    rets = p.merge(outputs)
    return rets if len(rets) > 1 else rets[0]

def FProp(self, theta, input_batch, state0=None):
  p = self.params
  src_segment_id = None
  with tf.name_scope(p.name):
    # Reshape to [t, b].
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings))
    ], tf.transpose(input_batch.ids))
    paddings = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    # Setup streaming states.
    if not state0:
      state0 = self.zero_state(theta, tf.shape(inputs)[1])
    state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)
    ps = paddings
    # Now the rnn layers.
    outputs_list = []
    for i in range(0, p.num_lstm_layers):
      layer = self.rnn[i]
      ys, state1.rnn[i] = layer.FProp(
          theta.rnn[i], xs, ps, state0=state0.rnn[i])
      ys = self.dropout.FProp(theta.dropout, ys)
      if i >= p.residual_start:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      else:
        xs = ys
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    return py_utils.NestedMap(
        encoded=xs,
        padding=tf.squeeze(ps, [2]),
        segment_id=src_segment_id,
        state=state1)

def FProp(self, theta, *args):
  p = self.params
  with tf.name_scope(p.name) as name_scope:
    for i, arg in enumerate(args):
      if not isinstance(arg, tf.Tensor):
        tf.logging.info(
            'FProp non-Tensor input in {}: arg_{} arg = {}'.format(
                name_scope, i, arg))
      else:
        tf.logging.info(
            'FProp inputs in {}: arg_{} shape = {} dtype = {}'.format(
                name_scope, i, arg.shape, arg.dtype.name))
    if len(args) == 1:
      return args[0]
    else:
      return args

def FProp(self, theta, inputs, paddings=None): """Apply batch normalization. Args: theta: A `.NestedMap` object containing weights' values of this layer and its children layers. inputs: The inputs tensor. Shaped [..., dim]. paddings: The paddings tensor. Shaped [..., 1], with the same rank as the input tensor. Returns: Output after applying batch normalization, with the same shape as 'inputs'. """ p = self.params if paddings is None: paddings = self._GetDefaultPaddings(inputs) with tf.name_scope(p.name): norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments( theta, inputs, paddings) with tf.control_dependencies([ py_utils.assert_greater_equal(norm_variance, tf.zeros_like(norm_variance)), py_utils.assert_shape_match([tf.shape(inputs)[-1]], tf.shape(norm_mean)), py_utils.assert_shape_match([tf.shape(inputs)[-1]], tf.shape(norm_variance)), ]): if p.use_fused_batch_norm_for_eval and self.do_eval: bn_output, _, _ = nn.fused_batch_norm( inputs, gamma, beta, norm_mean, norm_variance, self._epsilon, is_training=False) else: bn_output = tf.nn.batch_normalization(inputs, norm_mean, norm_variance, beta, gamma, self._epsilon) if p.set_padded_output_to_zero: bn_output *= 1.0 - paddings return bn_output
def ZeroState(self, theta, prepared_inputs, batch_size):
  """Creates a zero state NestedMap for this step.

  Args:
    theta: Variables used by sub-steps.
    prepared_inputs: Output from a call to PrepareExternalInputs.
    batch_size: The number of items in the batch that FProp will process.

  Returns:
    A NestedMap of ZeroState results for each sub-step.
  """
  state0 = py_utils.NestedMap()
  with tf.name_scope(self.params.name):
    for seq in self._seq:
      state0[seq.name] = seq.step.ZeroState(theta[seq.name],
                                            prepared_inputs[seq.name],
                                            batch_size)
  return state0

def TraverseLayer(layer, fn):
  """Traverses the layer tree and invokes fn(node) on each node.

  Args:
    layer: A BaseLayer.
    fn: A function of (layer) -> None.
  """
  if isinstance(layer, (list, tuple)):
    for layer_i in layer:
      TraverseLayer(layer_i, fn)
    return

  with tf.name_scope(layer.params.name):
    fn(layer)
    # Traverse all children in alphabetical order.
    for _, child in sorted(layer.children.items()):
      TraverseLayer(child, fn)

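# A minimal usage sketch for TraverseLayer: collecting the params name of
# every layer in a model tree (illustration only; `model` is a hypothetical
# BaseLayer, not part of the original source).
def _example_collect_layer_names(model):
  names = []
  TraverseLayer(model, lambda layer: names.append(layer.params.name))
  return names
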
def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
  """A single inference step for this step graph.

  Args:
    theta: Variables used by sub-steps.
    prepared_inputs: A NestedMap containing external_inputs that were
      pre-processed by the PrepareExternalInputs method of each sub-step.
      The keys are the names of the sub-steps.
    step_inputs: A NestedMap of [batch, ...] tensors. The structure of this
      depends on the graph implementation.
    padding: A 0/1 float tensor of shape [batch_size]; 1.0 means that this
      batch element is empty in this step.
    state0: A NestedMap of state variables produced by either ZeroState or
      a previous invocation of this FProp step. The keys are the names of
      the sub-steps.

  Returns:
    (output, state1), both of which are NestedMaps. output is
    implementation-dependent and is defined by the output_signature
    parameter. state1 is a NestedMap where the keys are names of sub-steps
    and the values are state outputs from their FProp methods.
  """
  p = self.params
  graph_tensors = builder_layers.GraphTensors()
  graph_tensors.StoreTensor('prepared_inputs', prepared_inputs)
  graph_tensors.StoreTensor('step_inputs', step_inputs)
  state1 = py_utils.NestedMap()
  with tf.name_scope(p.name):
    for seq in self._seq:
      tf.logging.vlog(1, 'GraphStep: call %s', seq.name)
      external = None
      if seq.external_signature:
        external = prepared_inputs[seq.name]
      template = py_utils.NestedMap(inputs=seq.signature.inputs)
      packed = template.Transform(graph_tensors.GetTensor)
      input_args = packed.inputs[0]
      out, seq_state1 = seq.step.FProp(theta[seq.name], external, input_args,
                                       padding, state0[seq.name])
      graph_tensors.StoreTensor(seq.signature.outputs[0], out)
      state1[seq.name] = seq_state1
  template = py_utils.NestedMap(inputs=self.output_signature.inputs)
  output_tensors = template.Transform(graph_tensors.GetTensor).inputs[0]
  return output_tensors, state1

def FProp(self, theta, inputs, *args):
  p = self.params
  with tf.name_scope(p.name) as scope:
    expert_dist = self._GetExpertDist(theta, inputs, *args)
    if not self.do_eval:
      summary_utils.histogram('soft_cond_{}'.format(scope), expert_dist)

    # Excludes non-variable extra_theta like global_step.
    var_set = set([key for key, _ in self.body.vars.FlattenItems()])
    values = []
    for key, value in theta.body.FlattenItems():
      if key in var_set and value is not None:
        # Weighted average for all variables created in the body layer.
        value = tf.einsum('i,i...->...', expert_dist, value)
      values.append(value)
    weighted_theta = theta.body.Pack(values)
    return self.body.FProp(weighted_theta, inputs, *args)

def FProp(self, theta, current_step):
  p = self.params
  with tf.name_scope(p.name):
    steps = self._best_step
    best_step = steps[0]
    last_step = steps[1]
    ref_step = tf.maximum(self._ref_step, best_step)
    f = self._cur_factor

    # Decay if no improvement within the window.
    new_factor = tf.where(last_step - ref_step < p.window, f,
                          tf.maximum(p.min_factor, f * p.decay))
    # Update ref_step if we decayed.
    new_step = tf.where(tf.equal(new_factor, f), ref_step, last_step)
    update_step = tf.assign(self._ref_step, new_step)
    with tf.control_dependencies([update_step]):
      return tf.assign(self._cur_factor, new_factor)

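# A pure-Python sketch of the plateau-decay rule above (illustration only;
# the function and argument names are hypothetical, not from the original
# source):
def _plateau_decay_reference(best_step, last_step, ref_step, factor, window,
                             decay, min_factor):
  ref_step = max(ref_step, best_step)
  # Keep the factor while the last improvement is within the window.
  new_factor = (factor if last_step - ref_step < window else
                max(min_factor, factor * decay))
  # Reset the reference step only if the factor actually changed.
  new_ref_step = ref_step if new_factor == factor else last_step
  return new_factor, new_ref_step
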
def FProp(self, theta, *args):
  p = self.params
  with tf.name_scope(p.name):
    tf.logging.vlog(1, 'layer %s', self.params.name)
    if p.repeat <= 1:
      for (name, ch) in self._seq:
        th = theta[name]
        args = _ToTuple(args)
        tf.logging.vlog(1, 'SequentialLayer: call %s %s %d %s',
                        ch.params.name, ch, len(args), str(args))
        args = ch.FProp(th, *args)
    else:
      for (ch, th) in zip(self.rep, theta.rep):
        args = _ToTuple(args)
        tf.logging.vlog(1, '  call %s %s %d %s', ch.params.name, ch,
                        len(args), str(args))
        args = ch.FProp(th, *args)
    args = _ToTuple(args)
    return args[0] if len(args) == 1 else args

def FProp(self, theta, inputs): """Apply projection to inputs. Args: theta: A NestedMap object containing weights' values of this layer and its children layers. inputs: The inputs tensor. Shaped [..., input_dims]. Returns: Projected inputs. """ p = self.params with tf.name_scope(p.name): computation_cost.Add( self, 'flops', tf.reduce_prod(tf.cast(tf.shape(inputs)[:-1], tf.int64)) * tf.cast(symbolic.ToTensor(p.input_dims * p.output_dims), tf.int64) * 2) return py_utils.ProjectLastDim(inputs, theta.w, p.input_dims, p.output_dims)
def reorder_tensor(reorder_mode,
                   values,
                   num_shards,
                   shard_size,
                   max_value=None,
                   axis=0):
  """Reorders a tensor from one sharding mode to another.

  This method reorders rows or columns (based on `axis`) of the tensor
  passed in from one sharding mode to another sharding mode. It uses matmul
  for reordering to be efficient on TPUs.

  Args:
    reorder_mode: Either "mod_to_div" or "div_to_mod".
    values: Tensor to reorder.
    num_shards: Number of shards.
    shard_size: Size of each shard.
    max_value: If dtype=tf.int32 and the maximum of the values is known,
      the gather can be implemented more efficiently as matmuls.
    axis: Axis to gather on. Defaults to 0 (rows).

  Returns:
    A tensor of the same shape as `values`, with the `axis` dimension
    reordered.

  Raises:
    NotImplementedError: If `reorder_mode` is not recognized.
  """
  values = tf.convert_to_tensor(values)
  with tf.name_scope("reorder_tensor_" + reorder_mode):
    num_ids = num_shards * shard_size
    # Elements to gather.
    seq_ids = tf.range(num_ids)
    if reorder_mode == "mod_to_div":
      local_ids = seq_ids // shard_size
      shard_ids = seq_ids % shard_size
      ids = local_ids + shard_ids * num_shards
    elif reorder_mode == "div_to_mod":
      shard_ids = seq_ids % num_shards
      local_ids = seq_ids // num_shards
      ids = local_ids + shard_ids * shard_size
    else:
      raise NotImplementedError(
          "Reorder mode: {} not implemented.".format(reorder_mode))
    return fast_gather(values, ids, num_ids, max_value, axis=axis)

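# A worked example of the id permutation above (illustration only; not part
# of the original source). With num_shards=2 and shard_size=4, "mod_to_div"
# computes, for seq_ids = [0..7]:
#   local_ids = seq_ids // 4 = [0, 0, 0, 0, 1, 1, 1, 1]
#   shard_ids = seq_ids % 4  = [0, 1, 2, 3, 0, 1, 2, 3]
#   ids = local_ids + shard_ids * 2 = [0, 2, 4, 6, 1, 3, 5, 7]
# so row j of the output is row ids[j] of the input:
def _example_reorder_tensor():
  values = tf.reshape(tf.range(8), [8, 1])
  reordered = reorder_tensor(
      "mod_to_div", values, num_shards=2, shard_size=4)
  # Expected row order: [0, 2, 4, 6, 1, 3, 5, 7].
  return reordered
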
def GenerateStepSeedPair(p, unused_global_step=None, op_seed=None):
  """Override py_utils.GenerateStepSeedPair to use GetOverWriteGlobalStep."""
  seed_dtype = tf.int32 if py_utils.use_tpu() else tf.int64
  if p.is_inference and p.random_seed is None:
    # Unlike tf.random*, stateless random ops are completely determined by
    # the passed-in seeds. This means at inference time the same inputs will
    # produce the same outputs, even if the model is supposed to have
    # randomness such as dropout during inference. We inject additional
    # randomness only during inference if the graph is exported with
    # random_seed=None as a workaround.
    return tf.random.uniform([2], maxval=seed_dtype.max, dtype=seed_dtype)

  with tf.name_scope('op_seed') as scope:
    global_step = tf.cast(GetOverWriteGlobalStep(), seed_dtype)
    step_seed = tf.cast(py_utils.GenerateSeedFromName(scope), seed_dtype)
    seeds = tf.stack([global_step, step_seed])

    if p.random_seed is not None:
      seeds += p.random_seed
    if op_seed is not None:
      seeds += op_seed
    return seeds