def _Moments(inputs, mask, enable_cross_replica_sum_on_tpu=False):
  """Computes mean and variance over the valid data points in inputs.

  The reduction runs over every axis except the last one, weighting each
  element of `inputs` by the (broadcast) `mask`.

  Args:
    inputs: Tensor to compute moments for. Must have the same rank as `mask`.
    mask: Non-negative weights (asserted below) that broadcast against
      `inputs`.
    enable_cross_replica_sum_on_tpu: If True and running on TPU, the sums and
      counts are additionally summed across replicas.

  Returns:
    A (mean, variance) tuple, each reduced over all but the last axis.
  """
  sanity_checks = [
      py_utils.assert_equal(tf.rank(inputs), tf.rank(mask)),
      py_utils.assert_greater_equal(mask, tf.zeros_like(mask)),
  ]
  inputs = py_utils.with_dependencies(sanity_checks, inputs)
  reduce_over_dims = tf.range(0, tf.rank(mask) - 1)

  weighted_inputs = inputs * tf.cast(mask, inputs.dtype)
  sum_v = tf.reduce_sum(weighted_inputs, reduce_over_dims)
  count_v = tf.reduce_sum(mask, reduce_over_dims)
  # Input shape is guaranteed to be a multiple of mask shape because the
  # inputs * mask op above was successfully broadcasted, so every mask element
  # stands for this many input elements.
  broadcast_factor = tf.reduce_prod(
      tf.shape(inputs)[:-1] // tf.shape(mask)[:-1])
  count_v *= tf.cast(broadcast_factor, count_v.dtype)

  sum_across_replicas = py_utils.use_tpu() and enable_cross_replica_sum_on_tpu
  if sum_across_replicas:
    sum_v = tf.tpu.cross_replica_sum(sum_v)
    count_v = tf.tpu.cross_replica_sum(count_v)

  # Clamp the count so an all-zero mask does not divide by zero.
  count_v = tf.maximum(count_v, 1.0)
  mean = sum_v / count_v

  deviation = inputs - mean
  sum_vv = tf.reduce_sum(deviation * deviation * mask, reduce_over_dims)
  if sum_across_replicas:
    sum_vv = tf.tpu.cross_replica_sum(sum_vv)

  variance = py_utils.with_dependencies(
      [py_utils.assert_greater_equal(sum_vv, tf.zeros_like(sum_vv))],
      sum_vv / count_v)
  return mean, variance
def FProp(self, theta, input_batch):
  """Encodes source as represented by `inputs` and `paddings`.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:
      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].

  Returns:
    A NestedMap containing:
      - encoded: The encoded features, a tensor of shape [time, batch, depth]
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
          (and all layers), or None otherwise.
  """
  p = self.params
  src_segment_id = None
  with tf.name_scope(p.name):
    ids = input_batch.ids
    # Validate the 2-D inputs, then switch to time-major layout.
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(ids), tf.shape(input_batch.paddings)),
    ], tf.transpose(ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    xs = self.ApplyClipping(theta, self.emb.EmbLookup(theta.emb, inputs))
    self._emb_out = xs

    # Now the rnn layers.
    # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell
    # with the same cc_schedule so that the RNN layer output is within
    # clipping range.
    # Layer 0's FProp result is used directly as a tensor, whereas the later
    # layers' results are unpacked as an (output, state) pair.
    first_out = self.rnn[0].FProp(theta.rnn[0], xs, ps)
    xs = self.dropout.FProp(theta.dropout, first_out)

    for i in range(1, p.num_lstm_layers):
      layer = self.rnn[i]
      ys, _ = layer.FProp(theta.rnn[i], xs, ps)
      ys = self.dropout.FProp(theta.dropout, ys)
      cfg = layer.params.cell if hasattr(layer.params, 'cell') else layer.params
      if cfg.num_input_nodes != cfg.num_output_nodes:
        # When cc_schedule is specified, make sure lstm_tpl is
        # QuantizedLSTMCell with the same cc_schedule so that the RNN layer
        # output is within clipping range.
        xs = ys
        continue
      # Residual skip is only possible when input and output depths agree.
      xs = self.ApplyClipping(theta, xs + ys)

    return py_utils.NestedMap(
        encoded=xs,
        padding=tf.squeeze(ps, [2]),
        segment_id=src_segment_id)
def IdsToStrings(self, ids, lens):
  """Takes integer matrices and returns vectors of strings.

  Args:
    ids: Integer matrix; each row is a sequence of token ids.
    lens: Vector of lengths with the same leading dimension as `ids` (asserted
      below); row r is decoded only up to lens[r].

  Returns:
    A string vector with one decoded entry per row of `ids`.
  """
  checked_ids = py_utils.with_dependencies(
      [py_utils.assert_same_dim0([ids, lens])], ids)

  def _DecodeRow(row_and_len):
    row, length = row_and_len
    return self._wpm_encoder.Decode(row[:length])

  return tf.map_fn(
      _DecodeRow,
      (checked_ids, lens),
      dtype=tf.string,
      parallel_iterations=30,
      back_prop=False)
def MakeCausalPadding(seq_len, block_size, left_context, right_context):
  """Makes the causal padding tensor for a full sequence.

  Args:
    seq_len: int or scalar int tensor. Sequence length.
    block_size: int. Number of time frames in a block.
    left_context: int. Left context size.
    right_context: int. Right context size.

  Returns:
    A tensor of [num_blocks, block_size, context_size] taking values in {0, 1},
    where context_size = block_size + (left_context - 1) + right_context.
    Element b, i, j is zero if in the b-th block, the i-th frame can access
    the j-th frame in the context.
  """
  seq_len = py_utils.with_dependencies([
      py_utils.assert_greater_equal(
          seq_len, 1, message='seq_len must be at least 1')
  ], seq_len)

  num_blocks = (seq_len + block_size - 1) // block_size
  padded_len = num_blocks * block_size
  context_size = block_size + (left_context - 1) + right_context
  # The context window of each block starts (left_context - 1) frames before
  # the block itself.
  context_offset = left_context - 1

  # [num_blocks, block_size]: absolute position of every query frame.
  src_positions = tf.reshape(tf.range(padded_len), [num_blocks, block_size])
  # [num_blocks]: absolute position of the first frame of each block.
  block_starts = tf.range(0, padded_len, block_size)
  # [context_size]: context positions relative to the block start.
  context_deltas = tf.range(context_size) - context_offset
  # [num_blocks, context_size]: absolute position of every context frame.
  tgt_positions = block_starts[:, tf.newaxis] + context_deltas[tf.newaxis, :]

  # [num_blocks, block_size, context_size]: src - tgt for every pair.
  position_diff = (
      src_positions[:, :, tf.newaxis] - tgt_positions[:, tf.newaxis, :])
  # A query may look up to right_context frames ahead (negative diff) and
  # strictly fewer than left_context frames behind (diff 0 is itself).
  in_window = tf.math.logical_and(position_diff >= -right_context,
                                  position_diff < left_context)

  # Both ends of the pair must fall inside the real (unpadded) sequence.
  src_in_range = src_positions < seq_len
  tgt_in_range = tf.math.logical_and(tgt_positions >= 0,
                                     tgt_positions < seq_len)
  both_in_range = tf.math.logical_and(src_in_range[:, :, tf.newaxis],
                                      tgt_in_range[:, tf.newaxis, :])

  allowed = tf.math.logical_and(in_window, both_in_range)
  return 1.0 - tf.cast(allowed, dtype=tf.float32)
def FProp(self, theta, inputs, paddings):
  """Apply convolution to inputs.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. It is expected to be of shape [batch, time,
      frequency, channel]. The time dimension corresponds to the height
      dimension as in images and the frequency dimension corresponds to the
      width dimension as in images.
    paddings: The paddings tensor, expected to be of shape [batch, time].

  Returns:
    outputs, out_paddings pair.
  """
  p = self.params
  with tf.name_scope(p.name):
    expected_shape = tf.concat(
        [tf.shape(paddings), [-1, symbolic.ToStatic(self.input_channels)]], 0)
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
        py_utils.assert_shape_match(tf.shape(inputs), expected_shape),
    ], inputs)

    # Zero out padded frames: [batch, time] -> [batch, time, 1, 1] so the
    # mask broadcasts over the frequency and channel axes.
    keep = 1.0 - tf.expand_dims(tf.expand_dims(paddings, -1), -1)
    inputs = inputs * keep

    # Apply conv on 'inputs'.
    out = self._ApplyConv(theta, inputs)
    if p.partial_conv:
      out = self._RescaleBoundary(out, paddings)

    # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
    # But there's likely no real problems. Trying to set it gives an error:
    # pooling with SAME padding is not implemented for dilation_rate > 1.
    # NOTE: we use window=p.filter_stride[0] to be compatible with legacy
    # implementation. Consider updating it to be the actual shape.
    stride = p.filter_stride[0]
    conv_padding = ComputeConvOutputPadding(
        paddings, window=stride, stride=stride)
    # The output itself is not re-masked: padded nodes are assumed to be
    # zeroed out, if necessary, by subsequent layers.
    out = py_utils.HasShape(
        out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
    return out, conv_padding
def FProp(self, theta, input_batch):
  """Embeds `input_batch.ids` and runs them through the RNN stack.

  Args:
    theta: A `.NestedMap` of weights for this layer and its children.
    input_batch: A `.NestedMap` with 2-D `ids` and `paddings` of matching
      shape (asserted below; presumably [batch, time]) and, when
      p.packed_input is set, `segment_ids` of the same layout.

  Returns:
    A `.NestedMap` with `encoded` (time-major features, projected by
    `final_proj` when p.lstm_cell_size * 2 != p.encoder_out_dim), `padding`
    (the transposed paddings) and `segment_id` (transposed segment ids, or
    None without packed input).
  """
  p = self.params
  with tf.name_scope(p.name):
    ids = input_batch.ids
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(ids), tf.shape(input_batch.paddings)),
    ], tf.transpose(ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)
    src_segment_id = (
        tf.expand_dims(tf.transpose(input_batch.segment_ids), 2)
        if p.packed_input else None)

    xs = self.ApplyClipping(theta, self.emb.EmbLookup(theta.emb, inputs))
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)

    # Now the rnn layers.
    outputs_list = []
    for i in range(p.num_lstm_layers):
      ys = self.rnn[i].FProp(theta.rnn[i], xs, ps, segment_id=src_segment_id)
      ys = self.dropout.FProp(theta.dropout, ys)
      if i < p.residual_start:
        xs = ys
      else:
        xs = self.ApplyClipping(theta, xs + ys)  # Residual skip
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    if p.lstm_cell_size * 2 != p.encoder_out_dim:
      # Project to the right depth.
      xs = self.final_proj.FProp(theta.final_proj, xs, ps)
      summary_utils.histogram('final_proj_out', xs)

    if src_segment_id is not None:
      src_segment_id = tf.squeeze(src_segment_id, [2])

    return py_utils.NestedMap(
        encoded=xs,
        padding=tf.squeeze(ps, [2]),
        segment_id=src_segment_id)
def SplitTensors(xs, num_splits):
  """Splits tensors in `xs` evenly into num_splits along the 1st dimension.

  Args:
    xs: A tuple (or list) of tensors. Each tensor's 1st dimension must be the
      same size; this is asserted at graph runtime.
    num_splits: A python integer. The shared 1st dimension must be at least
      this large (also asserted at graph runtime).

  Returns:
    A list with len(xs) entries. Entry k is the list of pieces `tf.split`
    produces for xs[k], using the split sizes from `ComputeSplits`; the i-th
    piece of every entry is the i-th split of the corresponding tensor along
    its first dimension.
  """
  # Shared batch size; reused by both assertions and by ComputeSplits.
  batch_size = tf.shape(xs[0])[0]

  # Assert first dim of all tensors in xs is equal and large enough.
  all_batch_dims = tf.stack([tf.shape(x)[0] for x in xs])
  all_batch_dims = py_utils.with_dependencies([
      py_utils.assert_equal(
          all_batch_dims,
          batch_size,
          message='first dim of tensors in xs must match'),
      py_utils.assert_greater_equal(
          batch_size,
          num_splits,
          message='first dim of tensors in xs must be greater than num_splits'
      ),
  ], all_batch_dims)

  splits = ComputeSplits(batch_size, num_splits)
  # Add the above assertions into the compute graph: every tf.split below
  # consumes `splits`, which now depends on the checked `all_batch_dims`.
  splits = py_utils.with_dependencies([all_batch_dims], splits)
  return [tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs]
def FProp(self, theta, input_batch, state0=None):
  """Encodes `input_batch` with the RNN stack while threading streaming state.

  Args:
    theta: A `.NestedMap` of weights for this layer and its children.
    input_batch: A `.NestedMap` with 2-D `ids` and `paddings` of matching
      shape (asserted below; presumably [batch, time]).
    state0: Optional initial state holding one entry per RNN layer under
      `rnn`. When falsy (e.g. None), `self.zero_state` is used instead.

  Returns:
    A `.NestedMap` with `encoded` (time-major features), `padding` (the
    transposed paddings), `segment_id` (always None here) and `state` (a
    `.NestedMap` whose `rnn` list holds each layer's final state).
  """
  p = self.params
  src_segment_id = None
  with tf.name_scope(p.name):
    ids = input_batch.ids
    # Reshape to [t, b]
    inputs = py_utils.with_dependencies([
        py_utils.assert_shape_match(tf.shape(ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(ids), tf.shape(input_batch.paddings)),
    ], tf.transpose(ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    # Setup streaming states.
    if not state0:
      state0 = self.zero_state(theta, tf.shape(inputs)[1])
    state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

    xs = self.ApplyClipping(theta, self.emb.EmbLookup(theta.emb, inputs))
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)

    # Now the rnn layers.
    outputs_list = []
    for i in range(p.num_lstm_layers):
      ys, state1.rnn[i] = self.rnn[i].FProp(
          theta.rnn[i], xs, ps, state0=state0.rnn[i])
      ys = self.dropout.FProp(theta.dropout, ys)
      if i < p.residual_start:
        xs = ys
      else:
        xs = self.ApplyClipping(theta, xs + ys)  # Residual skip
      outputs_list.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         outputs_list)

    return py_utils.NestedMap(
        encoded=xs,
        padding=tf.squeeze(ps, [2]),
        segment_id=src_segment_id,
        state=state1)
def ApplyBias():
  """Bias and update log_probs and consistent.

  Interpolates the beam's predicted distribution with a near-one-hot
  distribution built from the label at `time_step`, and updates whether each
  hypothesis has stayed consistent with the labels so far. Reads
  `num_hyps_per_beam`, `step_ids`, `labels`, `weights`, `time_step`,
  `states`, `bs_results` and `p` from the enclosing scope.

  Returns:
    A (log_probs, consistent) tuple: the log of the interpolated
    probabilities ([tgt_batch, vocab_size]) and the updated boolean
    consistency vector.
  """
  # num_hyps_per_beam * src_batch.
  tgt_batch = tf.shape(step_ids)[0]

  def TileForBeamAndFlatten(tensor):
    # [src_batch] -> [1, src_batch] -> [num_hyps_per_beam, src_batch]
    # -> [tgt_batch].
    tiled = tf.tile(tf.reshape(tensor, [1, -1]), [num_hyps_per_beam, 1])
    return tf.reshape(tiled, [tgt_batch])

  def _ColumnForHyps(tensor, step):
    # Column `step` of a [src_batch, time] tensor, replicated for every hyp.
    return TileForBeamAndFlatten(tf.gather(tensor, step, axis=1))

  # Consistent if step_ids == labels from previous step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # At step 0 prev_label re-reads column 0 and is meaningless, so step 0 is
  # forced to be locally consistent.
  prev_label = _ColumnForHyps(labels, tf.maximum(time_step - 1, 0))
  matches_prev = tf.equal(prev_label, tf.squeeze(step_ids, 1))
  local_consistence = tf.math.logical_or(tf.equal(time_step, 0), matches_prev)
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # Label and weight slices for the current time_step.
  label = _ColumnForHyps(labels, time_step)
  weight = _ColumnForHyps(weights, time_step)
  if p.bias_only_if_consistent:
    weight *= tf.cast(consistent, p.dtype)

  # Dense label -> smoothed one-hot distribution; the tiny uncertainty keeps
  # the label distribution free of 0 probs, which may cause issues with log.
  vocab_size = tf.shape(bs_results.log_probs)[1]
  uncertainty = tf.constant(1e-10, p.dtype)
  label_probs = tf.one_hot(
      label,
      vocab_size,
      on_value=1 - uncertainty,
      off_value=uncertainty / tf.cast(vocab_size - 1, p.dtype),
      dtype=p.dtype)  # [tgt_batch, vocab_size]
  pred_probs = tf.exp(bs_results.log_probs)

  # Linear interpolation of predicted probs and label probs.
  weight = tf.expand_dims(weight, 1)
  mixed = (1.0 - weight) * pred_probs + weight * label_probs
  probs = py_utils.with_dependencies([
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.),
  ], mixed)
  return tf.math.log(probs), consistent
def FProp(self, theta, inputs):
  """Applies batch normalization.

  Using the implementation in github.com/
  tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550

  Args:
    theta: A nested map object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor.  Shaped [..., dim].

  Returns:
    Output after applying batch normalization, with the same shape and dtype
    as 'inputs'. Computation happens in p.dtype.
  """
  p = self.params
  caller_dtype = inputs.dtype
  inputs = tf.cast(inputs, p.dtype)
  # The feature (last) dimension must match the size of beta.
  depth_check = py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                            tf.shape(theta.beta))
  inputs = py_utils.with_dependencies([depth_check], inputs)

  with tf.name_scope(p.name) as scope:
    if self.do_eval:
      # Eval normalizes with the stored moving statistics.
      mean, variance = theta.moving_mean, theta.moving_variance
    else:
      mean, variance = self._Moments(inputs, p.bn_group_size)
      mean = py_utils.CheckNumerics(
          mean, 'mean of {} failed numeric check'.format(scope))
      variance = py_utils.CheckNumerics(
          variance, 'variance of {} failed numeric check'.format(scope))

    outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta,
                                        theta.gamma, p.epsilon)
    outputs.set_shape(inputs.get_shape())
    return tf.cast(outputs, caller_dtype)
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  When not running on TPU and --transformer_encoder_truncates_inputs is set,
  every per-token input (ids, paddings, segment ids/positions and task ids)
  is truncated to the longest unpadded length in the batch. Padding must
  therefore be trailing, which is asserted.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:
      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - task_ids: If p.task_emb is provided, must contain per-token task ids
          of shape [batch, time].
      - segment_ids, segment_pos: Read only when p.packed_input is set.

  Returns:
    A NestedMap containing
      - encoded: The encoded features, either a tensor of shape
          [time, batch, depth], or a list of tensors if is_transparent is set
          in transformer_stack.
      - padding: of shape [time, batch]
      - segment_id: [time, batch] if packed inputs are supported by the model
          (and all layers), or None otherwise.
      - embedded_inputs: [time, batch, depth] embedded inputs tokens without
          positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    # Only read when a task embedding is configured; kept as a local so that
    # it can be truncated together with the other per-token inputs.
    task_ids = input_batch.task_ids if p.task_emb else None
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything beyond max_seq_length must be padding.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
      if task_ids is not None:
        # Without this, the task embedding keeps the untruncated time
        # dimension and cannot be added to the truncated input embeddings.
        task_ids = task_ids[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                            tf.reshape(input_ids, [-1]))
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax,
                                          tf.reshape(input_ids, [-1]))

    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb, task_ids)

    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    paddings = tf.cast(tf.transpose(paddings), py_utils.FPropDtype(p))
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

    if not self.do_eval and p.apply_source_mask:
      # Augment padding for masked source word positions.
      dtype = paddings.dtype
      source_mask = tf.where(
          tf.equal(input_ids, p.source_mask_id),
          tf.ones_like(input_ids, dtype=dtype),
          tf.zeros_like(input_ids, dtype=dtype))
      # Make sure padding is between 0 and 1.
      paddings = tf.clip_by_value(paddings + tf.transpose(source_mask), 0.0,
                                  1.0)

    encoded, padding, segment_id = self.transformer_stack.FProp(
        theta.transformer_stack, transformer_input, paddings, src_segment_id)
    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        segment_id=segment_id,
        embedded_inputs=orig_input_embs)
def MergeBeamSearchOutputs(max_hyps_per_beam, beam_search_outputs):
  """Merges beam search hyps from multiple decoders.

  Args:
    max_hyps_per_beam: the number of top hyps in the merged results. Must be
      less than or equal to total number of input hyps.
    beam_search_outputs: a list of BeamSearchDecodeOutput objects. Must share
      the same source_batch and max sequence length.

  Returns:
    A BeamSearchDecodeOutput object containing max_hyps_per_beam hypotheses
    per beam. Fields that are None in the first output stay None.

  Raises:
    ValueError: if a field is set in only some of `beam_search_outputs`, or if
      an output carries a field this function does not know how to reshape.
  """
  source_batch = tf.shape(beam_search_outputs[0].topk_hyps)[0]
  # field name -> one [source_batch, hyps_per_beam, -1] tensor per output.
  value_dict = {}
  for output in beam_search_outputs:
    hyps_per_beam = py_utils.with_dependencies([
        py_utils.assert_equal(source_batch, tf.shape(output.topk_hyps)[0]),
    ], tf.shape(output.topk_hyps)[1])
    for k, v in six.iteritems(output._asdict()):
      if v is None:
        continue
      if k == 'done_hyps':
        # done_hyps is stored transposed relative to the other fields; the
        # inverse transpose is applied when building the result.
        v = tf.transpose(v)
      value_dict.setdefault(k, []).append(
          tf.reshape(v, [source_batch, hyps_per_beam, -1]))

  # Concatenate the tensors along the 'num_hyps_per_beam' dimension.
  concatenated = {}
  for k, values in six.iteritems(value_dict):
    if len(values) != len(beam_search_outputs):
      raise ValueError('Incomplete values for %s: %s' %
                       (k, beam_search_outputs))
    concatenated[k] = tf.concat(values, axis=1)

  scores = concatenated['topk_scores']
  # Push zero-length hyps to the bottom of the ranking. The filler is built in
  # scores' own dtype: tf.where needs both branches to share a dtype, and a
  # bare -1e6 would always be float32.
  scores = tf.where(
      tf.equal(concatenated['topk_lens'], 0),
      tf.fill(tf.shape(scores), tf.cast(-1e6, scores.dtype)), scores)
  scores = tf.squeeze(scores, -1)

  # Select top max_hyps_per_beam indices per beam.
  _, top_indices = tf.nn.top_k(scores, max_hyps_per_beam)
  batch_ids = tf.tile(
      tf.expand_dims(tf.range(source_batch), -1), [1, max_hyps_per_beam])
  # [source_batch, max_hyps_per_beam, 2]
  gather_indices = tf.stack([batch_ids, top_indices], axis=-1)

  # Gather the merged top hyps according to 'gather_indices'.
  top = beam_search_outputs[0]._asdict()
  total_hyps = source_batch * max_hyps_per_beam
  for k, v in six.iteritems(concatenated):
    v = tf.gather_nd(v, gather_indices)
    if k == 'done_hyps':
      v = tf.transpose(tf.reshape(v, [total_hyps, -1]))
    elif k == 'topk_hyps':
      v = tf.reshape(v, [source_batch, max_hyps_per_beam])
    elif k == 'topk_ids':
      v = tf.reshape(v, [total_hyps, -1])
    elif k in ('topk_lens', 'topk_scores', 'topk_decoded'):
      v = tf.reshape(v, [total_hyps])
    else:
      raise ValueError('Unexpected field: %s' % k)
    top[k] = v
  return BeamSearchDecodeOutput(**top)
def ComputeAndUpdateMoments(self, theta, inputs, paddings=None):
  """Computes moments and updates state.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor.  Shaped [..., dim].
    paddings: The paddings tensor.  Shaped [..., 1], with the same rank as
      the input tensor. Defaults to `self._GetDefaultPaddings(inputs)`.

  Returns:
    Tuple of (mean, variance, beta, gamma). In eval mode mean/variance are the
    moving averages. In training they are the batch moments, or, with
    p.use_moving_avg_in_training, the moving averages (after the batch moments
    have been folded into them). With p.use_moving_avg_in_training, beta and
    gamma are the constants 0.0 and 1.0 rather than theta.beta / theta.gamma.
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)
  # The trailing dimension of paddings must be 1.
  inputs = py_utils.with_dependencies(
      [py_utils.assert_shape_match([tf.shape(paddings)[-1]], [1])], inputs)

  with tf.name_scope(p.name):
    if self.do_eval:
      # The mean and variance used for normalization.
      norm_mean = self.vars.moving_mean
      norm_variance = self.vars.moving_variance
    else:
      mean, variance = self._Moments(inputs, 1.0 - paddings,
                                     p.enable_cross_replica_sum_on_tpu)
      moving_mean = self.vars.moving_mean
      moving_variance = self.vars.moving_variance

      py_utils.UpdateBatchNormVars(moving_mean, mean, self._decay)
      py_utils.UpdateBatchNormVars(moving_variance, variance, self._decay)

      # Add some summaries for visualization.
      summary_utils.histogram('%s_mean' % p.name, tf.cast(mean, tf.float32))
      summary_utils.histogram('%s_variance' % p.name,
                              tf.cast(variance, tf.float32))
      summary_utils.histogram('%s_moving_mean' % p.name,
                              tf.cast(moving_mean, tf.float32))
      summary_utils.histogram('%s_moving_variance' % p.name,
                              tf.cast(moving_variance, tf.float32))
      summary_utils.histogram('%s_mean_diff' % p.name,
                              tf.cast(mean - moving_mean, tf.float32))
      summary_utils.histogram('%s_variance_diff' % p.name,
                              tf.cast(variance - moving_variance, tf.float32))

      if not p.use_moving_avg_in_training:
        # Use the batch statistics for normalization.
        norm_mean, norm_variance = mean, variance
      else:
        # Use the global statistics for normalization.
        # Control dependencies on mean and variance make sure
        # moving_mean and variance will be updated for every training step.
        norm_mean = py_utils.with_dependencies([mean], moving_mean)
        norm_variance = py_utils.with_dependencies([variance],
                                                   moving_variance)

    norm_mean = py_utils.CheckNumerics(
        norm_mean, 'mean of %s failed numeric check' % p.name)
    norm_variance = py_utils.CheckNumerics(
        norm_variance, 'variance of %s failed numeric check' % p.name)

    if p.use_moving_avg_in_training:
      beta, gamma = 0.0, 1.0
    else:
      beta, gamma = theta.beta, theta.gamma
    return norm_mean, norm_variance, beta, gamma