def _BeamSearchDecode(self, input_batch):
  """Encodes the source and runs beam-search decoding on it.

  Args:
    input_batch: Batch object exposing `src` (fed to the encoder) and `tgt`
      (used to attach extra decoding info to the encoder outputs).

  Returns:
    Whatever `self._ProcessBeamSearchDecodeOut` produces from the input batch,
    the augmented encoder outputs and the beam-search decoder outputs.
  """
  params = self.params
  with tf.name_scope('fprop'), tf.name_scope(params.name):
    enc_out = self.enc.FPropDefaultTheta(input_batch.src)
    # The decoder augments the encoder outputs with target-side information.
    enc_out = self.dec.AddExtraDecodingInfo(enc_out, input_batch.tgt)
    beam_search_out = self.dec.BeamSearchDecode(enc_out)
    return self._ProcessBeamSearchDecodeOut(input_batch, enc_out,
                                            beam_search_out)
def FProp(self, theta, *args):
  """Runs the sub layer on every device of the model split.

  The inputs are split along axis 0 into one shard per device, the sub layer
  is run on each shard under the corresponding device, and the per-device
  results are concatenated back along axis 0.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    *args: A tuple of Tensors (one or more). Every tensor's first dimension is
      the same (the batch dimension).

  Returns:
    The sub layer's output: a tuple when the sub layer returns a tuple,
    otherwise a single value.
  """
  with tf.name_scope(self.params.name):
    assert all(isinstance(x, tf.Tensor) for x in args)
    cluster = self.cluster
    num_devices = cluster.num_devices_per_split
    if num_devices == 1:
      # Nothing to split: run the sub layer directly.
      return self.sub.FProp(theta.sub, *args)

    shards = py_utils.SplitRecursively(list(args), num_devices, axis=0)
    per_device_outs = []
    for idx, shard_args in enumerate(shards):
      device = cluster.WorkerDeviceInModelSplit(idx)
      tf.logging.info('%d on device %s', idx, device)
      with tf.device(device):
        shard_out = self.sub.FProp(theta.sub, *shard_args)
      # Tuples become lists so they can be concatenated element-wise; anything
      # else is treated as a single tensor.
      if isinstance(shard_out, tuple):
        per_device_outs.append(list(shard_out))
      else:
        per_device_outs.append(shard_out)

    merged = py_utils.ConcatRecursively(per_device_outs, axis=0)
    # A list result means the sub layer returned tuples; restore that type.
    return tuple(merged) if isinstance(merged, list) else merged
def FProp(self, theta, source_id, source_paddings, target_id, target_paddings,
          source_segment_id, target_segment_id, source_pos_id, target_pos_id,
          source_task_id, target_task_id):
  """Computes source (and optionally target) embeddings.

  Source embeddings are always computed through `self.GetEmbeddings`. Target
  embeddings are computed only when `p.add_tgt_embedding_layer` is set;
  otherwise `target_vecs` is None.

  Args:
    theta: Weights of this layer and its children.
    source_id: Source ids passed to `GetEmbeddings`.
    source_paddings: Source paddings, returned unchanged.
    target_id: Target ids passed to `GetEmbeddings`.
    target_paddings: Target paddings, returned unchanged.
    source_segment_id: Source segment ids, returned unchanged.
    target_segment_id: Target segment ids, returned unchanged.
    source_pos_id: Source position ids passed to `GetEmbeddings`.
    target_pos_id: Target position ids passed to `GetEmbeddings`.
    source_task_id: Source task ids passed to `GetEmbeddings`.
    target_task_id: Target task ids passed to `GetEmbeddings`.

  Returns:
    The tuple (source_vecs, source_paddings, target_vecs, target_paddings,
    source_segment_id, target_segment_id, None, None), extended with
    (source_task_id, target_task_id) when `p.ret_task_ids` is set.
  """
  p = self.params
  with tf.name_scope(p.name):
    if p.enc_task_emb:
      src_task_emb, src_task_emb_theta = self.src_task_emb, theta.src_task_emb
    else:
      src_task_emb, src_task_emb_theta = None, None
    source_vecs = self.GetEmbeddings(
        theta.src_token_emb, self.src_token_emb, theta.src_pos_emb,
        self.src_pos_emb, theta.src_dropout, self.src_dropout, source_id,
        source_pos_id, src_task_emb_theta, src_task_emb, source_task_id)

    target_vecs = None
    if p.add_tgt_embedding_layer:
      # NOTE(review): the target task embedding is gated on p.enc_task_emb,
      # the same flag as the source side -- confirm this is intended.
      if p.enc_task_emb:
        tgt_task_emb, tgt_task_emb_theta = (self.tgt_task_emb,
                                            theta.tgt_task_emb)
      else:
        tgt_task_emb, tgt_task_emb_theta = None, None
      target_vecs = self.GetEmbeddings(
          theta.tgt_token_emb, self.tgt_token_emb, theta.tgt_pos_emb,
          self.tgt_pos_emb, theta.tgt_dropout, self.tgt_dropout, target_id,
          target_pos_id, tgt_task_emb_theta, tgt_task_emb, target_task_id)

    outputs = (source_vecs, source_paddings, target_vecs, target_paddings,
               source_segment_id, target_segment_id, None, None)
    if p.ret_task_ids:
      outputs += (source_task_id, target_task_id)
    return outputs
def FProp(self, theta, *args):
  """Runs the body and appends the activations of the fetched descendants.

  Args:
    theta: Weights of this layer; `theta.body` is passed to the body.
    *args: Inputs forwarded to `self.body.FProp`.

  Returns:
    The body's outputs (normalized with `_ToTuple`) followed by the
    `activation` of each descendant named in `p.fetches`, in that order.
  """
  p = self.params
  with tf.name_scope(p.name):
    outputs = _ToTuple(self.body.FProp(theta.body, *args))
    # Fetches are looked up only after the body has run.
    outputs += tuple(
        self.body.GetDescendant(fetch_name).activation
        for fetch_name in p.fetches)
    return outputs
def FProp(self, theta, *args):
  """Runs p.repeat copies of self.body.FProp independently.

  Args:
    theta: Layer model parameters. The shape of each variable in theta is
      always [p.repeat, ...]. And the i-th slice theta[i] becomes theta of the
      i-th copy of self.body.
    *args: Input arguments. The shape of each tensor in args is always
      [p.repeat, ....]. And the list [arg[i] for arg in args] becomes inputs
      to the i-th copy of self.body.FProp. None entries skip the shape check.

  Returns:
    The accumulated output_tensors. Each tensor t in the return has the shape
    [p.repeat, ....] and the tuple (t[i] for i in output_tensors) is the
    return tuple of the i-th self.body.FProp. A single tensor (the first
    output) is returned when exactly one input argument was given.
  """
  p = self.params
  # Keep the tensors returned by HasShape. The previous code rebound the loop
  # variable, which threw the checked tensors away; presumably HasShape
  # attaches the shape check to the tensor it returns, so dropping that
  # result dropped the check as well.
  args = tuple(
      arg if arg is None else py_utils.HasShape(arg, [p.repeat], ndims=1)
      for arg in args)

  theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)
  inputs = py_utils.NestedMap(theta=theta_stack, args=list(args))
  # Infer out_shapes from FPropMeta.
  out_shapes = self._InferOutShapes(args)

  def _CellFn(unused_theta, unused_state0, inputs):
    """Recurrent cell function wrapper of body.FProp."""
    # Sets shapes for both theta and inputs to self.body.FProp: each slice
    # has the shape of its stacked source minus the leading dimension.
    for dst, src in zip(inputs.args + inputs.theta.Flatten(),
                        list(args) + theta_stack.Flatten()):
      if src is not None:
        dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

    # Runs the actual body.FProp.
    fprop_outputs = _ToTuple(self.body.FProp(inputs.theta, *inputs.args))
    assert len(fprop_outputs) == len(out_shapes)

    # Passes fprop outputs to the next layer through state.
    state1 = py_utils.NestedMap(outputs=list(fprop_outputs))
    return state1, py_utils.NestedMap()

  with tf.name_scope(p.name):
    # Initiate state0 with inferred output shapes. All outputs use the dtype
    # of the first input argument.
    state0 = py_utils.NestedMap(
        outputs=[tf.zeros(shape, args[0].dtype) for shape in out_shapes])

    # Runs body.FProp p.repeat times using Recurrent.
    acc_states, _ = recurrent.Recurrent(
        theta=py_utils.NestedMap(),
        state0=state0,
        inputs=inputs,
        cell_fn=_CellFn)

    # Retrieves fprop outputs from the accumulated states and sets shapes.
    output_tensors = tuple(acc_states.outputs)
    for tensor, out_shape in zip(output_tensors, out_shapes):
      tensor.set_shape(tf.TensorShape([p.repeat] + out_shape.as_list()))

    # NOTE(review): the single-value return is keyed on the number of inputs,
    # not the number of outputs -- kept as is for backward compatibility.
    return output_tensors[0] if len(args) == 1 else tuple(output_tensors)
def FProp(self, theta, *args):
  r"""Applies p.fn(x, \*\*p.kwargs) to every non-None arg.

  Args:
    theta: Unused.
    *args: Inputs; None entries are passed through as None.

  Returns:
    A tuple of results when more than one arg is given, otherwise the single
    result itself.
  """
  del theta
  p = self.params
  with tf.name_scope(p.name):
    mapped = tuple(
        p.fn(x, **p.kwargs) if x is not None else None for x in args)
    if len(mapped) > 1:
      return mapped
    return mapped[0]
def PrepareExternalInputs(self, theta, external_inputs):
  """Prepares external inputs for each sub-step.

  `external_inputs` is stored in a GraphTensors container under the name
  'external_inputs'. For every sub-step with an external signature, the
  signature's inputs are resolved against that container and the first
  resolved entry is handed to the sub-step's own PrepareExternalInputs.
  Sub-steps without an external signature get an empty NestedMap.

  Args:
    theta: variables used by sub-steps.
    external_inputs: A NestedMap of [n_batch, ...] tensors.

  Returns:
    A NestedMap of prepared inputs, where the keys are the names of each
    sub-step.
  """
  tensors = builder_layers.GraphTensors()
  tensors.StoreTensor('external_inputs', external_inputs)
  prepared = py_utils.NestedMap()
  with tf.name_scope(self.params.name):
    for seq in self._seq:
      if not seq.external_signature:
        prepared[seq.name] = py_utils.NestedMap()
        continue
      signature_map = py_utils.NestedMap(inputs=seq.external_signature.inputs)
      resolved = signature_map.Transform(tensors.GetTensor)
      prepared[seq.name] = seq.step.PrepareExternalInputs(
          theta[seq.name], resolved.inputs[0])
  return prepared
def FProp(self, theta, source_vecs, source_paddings, target_vecs,
          target_paddings, source_segment_id, target_segment_id,
          transparent_acc, transparent_acc_helper, source_task_id=None,
          target_task_id=None):
  """Dispatches to the shared GPipe transformer decoder or encoder FProp.

  When `p.has_aux_atten` is set the decoder helper is used, otherwise the
  encoder helper. Both receive exactly the same arguments.

  Returns:
    Whatever the selected shared helper returns.
  """
  p = self.params
  with tf.name_scope(p.name):
    shared_fprop = (
        _common_gpipe_transformer_decoder_fprop
        if p.has_aux_atten else _common_gpipe_transformer_encoder_fprop)
    return shared_fprop(self, GPipeTransformerLayer, theta, source_vecs,
                        source_paddings, target_vecs, target_paddings,
                        source_segment_id, target_segment_id, transparent_acc,
                        transparent_acc_helper, source_task_id, target_task_id)
def fast_gather(values, ids, ids_size, max_value=None, axis=0,
                batch_major_state=True):
  """Fast implementation of gather on TPUs.

  Both `values` and `ids` are converted to tensors and the gather itself is
  delegated to a `_Gatherer` built from `ids` and `ids_size`.

  Args:
    values: Values to gather from.
    ids: ids (rows to gather)
    ids_size: id space size.
    max_value: Optional hint on maximum value for int32 that allows to speed
      up the gather operation.
    axis: axis to gather on. Defaults to 0 (rows).
    batch_major_state: Whether the values to gather from use batch major or
      not. Defaults to True.

  Returns:
    Gathered values.

  Raises:
    Value error if values is type int64 (the check is not in this function;
    presumably it is raised by `_Gatherer` -- confirm).
  """
  values_t = tf.convert_to_tensor(values)
  ids_t = tf.convert_to_tensor(ids)
  with tf.name_scope("fast_gather"):
    gatherer = _Gatherer(ids_t, ids_size)
    return gatherer(
        values_t,
        max_value=max_value,
        axis=axis,
        batch_major_state=batch_major_state)
def Step(recurrent_theta, state0, inputs):
  """Computes one sampling step of the decoder.

  Args:
    recurrent_theta: NestedMap with `theta`, `encoder_outputs` and
      `random_seed`.
    state0: Previous state with `ids`, `timestep` and `bs_state`.
    inputs: Unused.

  Returns:
    (state1, empty NestedMap), where state1 carries the incremented
    `timestep`, the step `logits`, the sampled `ids` and the new `bs_state`.
  """
  del inputs
  with tf.name_scope('single_sampler_step'):
    # Compute logits and states; the previous ids are fed as [batch, 1].
    step_out, pre_state = pre_step_callback(
        recurrent_theta.theta,
        recurrent_theta.encoder_outputs,
        tf.expand_dims(state0.ids, 1),
        state0.bs_state,
        num_hyps_per_beam=1)
    batch = tf.shape(step_out.log_probs)[0]

    state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
    state1.logits = step_out.log_probs

    # Sample ids from temperature-scaled logits. The seed combines the
    # per-run random seed with the current timestep. Result is [batch].
    scaled_logits = state1.logits / p.temperature
    step_seed = tf.stack([recurrent_theta.random_seed, state0.timestep])
    sampled = tf.random.stateless_categorical(
        scaled_logits,
        num_samples=1,
        seed=step_seed,
        dtype=state0.ids.dtype,
        name='sample_next_id')
    state1.ids = tf.reshape(sampled, [batch])

    if 'is_last_chunk' in step_out and p.target_eoc_id >= 0:
      # In the last chunk, a sampled end-of-chunk id is replaced by the
      # end-of-sequence id.
      eoc_in_last_chunk = tf.math.logical_and(
          step_out.is_last_chunk, tf.equal(state1.ids, p.target_eoc_id))
      eos_ids = tf.fill(tf.shape(state1.ids), p.target_eos_id)
      state1.ids = tf.where(eoc_in_last_chunk, eos_ids, state1.ids)

    state1.bs_state = post_step_callback(recurrent_theta.theta,
                                         recurrent_theta.encoder_outputs,
                                         state1.ids, pre_state)
  return state1, py_utils.NestedMap()
def FProp(self, theta, x):
  """Feeds `x` through every child in `self._seq`, in order.

  Args:
    theta: Weights of this layer, indexable by child name.
    x: Input to the first child.

  Returns:
    The output of the last child (or `x` unchanged if there are no children).
  """
  layer_name = self.params.name
  tf.logging.vlog(1, 'layer %s', layer_name)
  with tf.name_scope(layer_name):
    for child_name, child in self._seq:
      child_theta = theta[child_name]
      tf.logging.vlog(1, ' call %s %s %s', child.params.name, child, x)
      x = child.FProp(child_theta, x)
  return x
def FProp(self, theta, input_batch):
  """Encodes source as represented by `inputs` and `paddings`.

  The ids are transposed, embedded and clipped, then passed through the
  stacked RNN layers with dropout after each one. Layers after the first add
  a residual connection when their input and output widths match.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:
      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].

  Returns:
    A NestedMap containing:
      - encoded: The encoded features, a tensor of shape [time, batch, depth]
      - padding: of shape [time, batch]
      - segment_id: always None in this implementation.
  """
  p = self.params
  src_segment_id = None
  with tf.name_scope(p.name):
    shape_checks = [
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
    ]
    inputs = py_utils.with_dependencies(shape_checks,
                                        tf.transpose(input_batch.ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    self._emb_out = xs

    # When cc_schedule is specified, make sure lstm_tpl is QuantizedLSTMCell
    # with the same cc_schedule so that the RNN layer output is within
    # clipping range.
    # The first layer returns a single value; later layers return a pair.
    xs = self.rnn[0].FProp(theta.rnn[0], xs, ps)
    xs = self.dropout.FProp(theta.dropout, xs)

    for i in range(1, p.num_lstm_layers):
      layer = self.rnn[i]
      ys, _ = layer.FProp(theta.rnn[i], xs, ps)
      ys = self.dropout.FProp(theta.dropout, ys)
      # Width information lives on `params.cell` when present, otherwise on
      # the layer params themselves.
      layer_params = (
          layer.params.cell if hasattr(layer.params, 'cell') else layer.params)
      if layer_params.num_input_nodes != layer_params.num_output_nodes:
        # Widths differ, so no residual. The cc_schedule requirement noted
        # above applies here too.
        xs = ys
      else:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)

    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
def FProp(self, theta, *args):
  """Runs self.body.FProp repeatedly, chaining outputs into inputs.

  The stacked body theta is fed to `recurrent.Recurrent` as its inputs, and
  the FProp arguments travel through the recurrent state, so each iteration
  consumes the previous iteration's outputs.

  Args:
    theta: Weights of this layer; `theta.body` is stacked for the recurrence.
    *args: Initial FProp arguments. Entries may be None.

  Returns:
    The final outputs: a single value when exactly one arg was given,
    otherwise a tuple with one entry per arg (None where the state held no
    value).
  """
  p = self.params
  num_args = len(args)
  # Stacked body variables, one slice per repetition.
  theta_stack = _MaybeStackExtraTheta(theta.body, self.body.vars, p.repeat)

  def _StateKey(idx):
    """Returns the state key that holds the idx-th FProp arg."""
    return '_s{}'.format(idx)

  def _ArgsToState(arg_list):
    """Returns a NestedMap from a list of FProp args."""
    state = py_utils.NestedMap()
    # States cannot contain None tensors, so only non-None args are stored;
    # the key records the arg's position.
    for idx in range(num_args):
      value = arg_list[idx]
      if value is not None:
        state[_StateKey(idx)] = value
    return state

  def _StateToArgs(state):
    """Returns a list of FProp args from a NestedMap."""
    arg_list = []
    for idx in range(num_args):
      key = _StateKey(idx)
      value = state[key] if key in state else None
      if value is not None:
        # Restore the static shape of the matching original arg.
        value.set_shape(args[idx].shape)
      arg_list.append(value)
    return arg_list

  def _CellFn(unused_theta, state0, theta_i):
    """Recurrent cell function wrapper of body.FProp."""
    # Retrieves fprop arguments from state and sets shapes.
    body_inputs = _StateToArgs(state0)

    # Sets shapes for theta_i as well: a slice drops the leading dimension.
    for dst, src in zip(theta_i.Flatten(), theta_stack.Flatten()):
      if src is not None:
        dst.set_shape(tf.TensorShape(src.shape.as_list()[1:]))

    # Runs the actual body.FProp.
    body_outputs = _ToTuple(self.body.FProp(theta_i, *body_inputs))
    assert len(body_outputs) == len(body_inputs)

    # Passes fprop outputs to the next layer through state.
    return _ArgsToState(body_outputs), py_utils.NestedMap()

  with tf.name_scope(p.name):
    # Runs body.FProp k times using Recurrent, where k is dim 0 of the
    # stacked theta. The cell's theta is passed through `inputs`.
    _, final_state = recurrent.Recurrent(
        theta=py_utils.NestedMap(),
        state0=_ArgsToState(args),
        inputs=theta_stack,
        cell_fn=_CellFn)

    # Retrieves fprop outputs from the final state and sets shapes.
    output_tensors = _StateToArgs(final_state)
    if num_args == 1:
      return output_tensors[0]
    return tuple(output_tensors)
def FProp(self, theta, current_step):
  """Returns the cosine-decayed value at `current_step`.

  The value starts at `p.initial_value`, follows half a cosine period and
  stays at `p.final_value` once `current_step` reaches `p.total_steps`.

  Args:
    theta: Unused.
    current_step: The current step; cast to float32.

  Returns:
    final_value + 0.5 * (initial_value - final_value) *
    (1 + cos(pi * min(1, current_step / total_steps))).
  """
  p = self.params
  assert p.total_steps > 0
  assert p.initial_value > p.final_value
  with tf.name_scope(p.name):
    half_gap = 0.5 * (p.initial_value - p.final_value)
    # Fraction of the schedule completed, capped at 1.
    progress = tf.minimum(
        1.0, tf.cast(current_step, tf.float32) / p.total_steps)
    cosine_term = 1 + tf.cos(math.pi * progress)
    return p.final_value + half_gap * cosine_term
def FProp(self, theta, *args):
  r"""Calls p.fn on the given args inside this layer's name scope.

  Args:
    theta: Unused.
    *args: A tuple of Tensors (one or more).

  Returns:
    fn(\*args).
  """
  params = self.params
  with tf.name_scope(params.name):
    result = params.fn(*args)
  return result
def FProp(self, theta, input_batch):
  """Encodes the source ids with the stacked RNN layers.

  Args:
    theta: Weights of this layer and its children.
    input_batch: Object with `ids` and `paddings` (their shapes are asserted
      to match and to be rank 2), plus `segment_ids` when `p.packed_input` is
      set.

  Returns:
    A NestedMap with:
      - encoded: The output of the last stage.
      - padding: The transposed paddings with the trailing unit dim removed.
      - segment_id: The transposed segment ids when `p.packed_input` is set,
        otherwise None.
  """
  p = self.params
  with tf.name_scope(p.name):
    shape_checks = [
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
    ]
    inputs = py_utils.with_dependencies(shape_checks,
                                        tf.transpose(input_batch.ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    src_segment_id = None
    if p.packed_input:
      src_segment_id = tf.expand_dims(
          tf.transpose(input_batch.segment_ids), 2)

    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)

    # Now the rnn layers. Every layer's output is kept for the optional
    # transparent merge below.
    per_layer_outputs = []
    for i in range(0, p.num_lstm_layers):
      ys = self.rnn[i].FProp(theta.rnn[i], xs, ps, segment_id=src_segment_id)
      ys = self.dropout.FProp(theta.dropout, ys)
      if i < p.residual_start:
        xs = ys
      else:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      per_layer_outputs.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         per_layer_outputs)

    if p.lstm_cell_size * 2 != p.encoder_out_dim:
      # Project to the right depth.
      xs = self.final_proj.FProp(theta.final_proj, xs, ps)
      summary_utils.histogram('final_proj_out', xs)

    if src_segment_id is not None:
      src_segment_id = tf.squeeze(src_segment_id, [2])
    return py_utils.NestedMap(
        encoded=xs, padding=tf.squeeze(ps, [2]), segment_id=src_segment_id)
def FProp(self, theta, inputs, paddings):
  """Apply convolution to inputs.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. It is expected to be of shape [batch, time,
      frequency, channel]. The time dimension corresponds to the height
      dimension as in images and the frequency dimension corresponds to the
      width dimension as in images.
    paddings: The paddings tensor, expected to be of shape [batch, time].

  Returns:
    outputs, out_paddings pair.
  """
  p = self.params
  with tf.name_scope(p.name):
    expected_input_shape = [
        py_utils.assert_shape_match(tf.shape(paddings), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(inputs),
            tf.concat([
                tf.shape(paddings),
                [-1, symbolic.ToStatic(self.input_channels)]
            ], 0)),
    ]
    inputs = py_utils.with_dependencies(expected_input_shape, inputs)

    # Zeroing out padded inputs: paddings gets two trailing unit dims so it
    # broadcasts against the inputs.
    padding_expanded = tf.expand_dims(tf.expand_dims(paddings, -1), -1)
    inputs = inputs * (1.0 - padding_expanded)

    # Apply conv on 'inputs'.
    out = self._ApplyConv(theta, inputs)
    if p.partial_conv:
      out = self._RescaleBoundary(out, paddings)

    # NOTE: this may be slightly inaccurate when p.dilation_rate[0] > 1.
    # But there's likely no real problems. Trying to set it gives an error:
    # pooling with SAME padding is not implemented for dilation_rate > 1.
    # NOTE: we use window=p.filter_stride[0] to be compatible with legacy
    # implementation. Consider updating it to be the actual shape.
    time_stride = p.filter_stride[0]
    conv_padding = ComputeConvOutputPadding(
        paddings, window=time_stride, stride=time_stride)

    # The output is deliberately not re-masked with conv_padding: padded
    # nodes are assumed to be zeroed out, if necessary, by subsequent layers.
    out = py_utils.HasShape(
        out, symbolic.ToStatic(self.OutShape(tf.shape(inputs))))
    return out, conv_padding
def FProp(self, theta, inputs):
  """Adds the bias `theta.b` to the inputs.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    inputs: The inputs tensor. Shaped [..., dims].

  Returns:
    Inputs plus bias.
  """
  with tf.name_scope(self.params.name):
    biased = inputs + theta.b
  return biased
def _Traverse(layer):
  """Registers one accumulator per cost metric on a layer and its descendants.

  Lists and tuples of layers are traversed element by element. For a single
  layer, a scalar AddingAccumulator is registered for every entry of
  COST_METRICS (using the dtype stored there), then the children are visited
  in sorted order.
  """
  if isinstance(layer, (list, tuple)):
    for sub_layer in layer:
      _Traverse(sub_layer)
    return

  with tf.name_scope(layer.params.name):
    for metric_name in COST_METRICS:
      accumulator = bn_layers.AddingAccumulator(
          shape=[], dtype=COST_METRICS[metric_name])
      layer.RegisterAccumulator(metric_name, accumulator)

    for _, child_layer in sorted(layer.children.items()):
      _Traverse(child_layer)
def FProp(self, theta, *args):
  """Runs every child on the same inputs and merges their outputs.

  Args:
    theta: Weights of this layer, indexable by child name.
    *args: Inputs passed unchanged to every child in `self._seq`.

  Returns:
    The result of `p.merge` on the list of per-child output tuples; when the
    merged result has a single element, that element is returned directly.
  """
  p = self.params
  with tf.name_scope(p.name):
    # Each child's output is normalized to a tuple before merging.
    branch_outputs = []
    for child_name, child in self._seq:
      result = child.FProp(theta[child_name], *args)
      branch_outputs.append(
          tuple(result) if isinstance(result, (list, tuple)) else (result,))

    merged = p.merge(branch_outputs)
    if len(merged) > 1:
      return merged
    return merged[0]
def CollectVarHistogram(vs_gs):
  """Adds histogram summaries for variables and gradients.

  Args:
    vs_gs: Object whose `FlattenItems()` yields (name, (var, grad)) pairs.
  """
  for raw_name, (var, grad) in vs_gs.FlattenItems():
    name = py_utils.SanitizeScopeKey(raw_name)
    with tf.device(var.device), tf.name_scope(name + '/summary'):
      if isinstance(grad, tf.IndexedSlices):
        # Sparse gradient: only look at the rows that were actually updated.
        var = tf.gather(var, grad.indices)
        grad = grad.values
      if var.dtype.is_complex:
        # Histogram the magnitudes of complex values.
        var = tf.abs(var)
        grad = tf.abs(grad)

    for prefix, tensor in (('var_hist/', var), ('grad_hist/', grad)):
      histogram(prefix + name, tensor)
def FProp(self, theta, input_batch, state0=None):
  """Encodes the source ids, threading per-layer RNN state through.

  Args:
    theta: Weights of this layer and its children.
    input_batch: Object with `ids` and `paddings`; their shapes are asserted
      to match and to be rank 2.
    state0: Optional previous state with an `rnn` list holding one entry per
      layer. When falsy, `self.zero_state` is used.

  Returns:
    A NestedMap with:
      - encoded: The output of the last stage.
      - padding: The transposed paddings with the trailing unit dim removed.
      - segment_id: always None in this implementation.
      - state: NestedMap whose `rnn` list holds each layer's new state.
  """
  p = self.params
  src_segment_id = None
  with tf.name_scope(p.name):
    # Transpose ids (asserted rank 2 above) before the embedding lookup.
    shape_checks = [
        py_utils.assert_shape_match(tf.shape(input_batch.ids), [-1, -1]),
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
    ]
    inputs = py_utils.with_dependencies(shape_checks,
                                        tf.transpose(input_batch.ids))
    ps = tf.expand_dims(tf.transpose(input_batch.paddings), 2)

    # Setup streaming states.
    if not state0:
      state0 = self.zero_state(theta, tf.shape(inputs)[1])
    state1 = py_utils.NestedMap(rnn=[None] * p.num_lstm_layers)

    xs = self.emb.EmbLookup(theta.emb, inputs)
    xs = self.ApplyClipping(theta, xs)
    summary_utils.histogram('input_emb', xs)
    xs = self.dropout.FProp(theta.dropout, xs)

    # Now the rnn layers. Every layer's output is kept for the optional
    # transparent merge below.
    per_layer_outputs = []
    for i in range(0, p.num_lstm_layers):
      ys, state1.rnn[i] = self.rnn[i].FProp(
          theta.rnn[i], xs, ps, state0=state0.rnn[i])
      ys = self.dropout.FProp(theta.dropout, ys)
      if i < p.residual_start:
        xs = ys
      else:
        xs += ys  # Residual skip
        xs = self.ApplyClipping(theta, xs)
      per_layer_outputs.append(xs)
      summary_utils.histogram('layer_out_%s' % i, xs)

    if p.is_transparent:
      xs = self.transparent_merger.FProp(theta.transparent_merger,
                                         per_layer_outputs)

    return py_utils.NestedMap(
        encoded=xs,
        padding=tf.squeeze(ps, [2]),
        segment_id=src_segment_id,
        state=state1)
def FProp(self, theta, *args):
  """Logs every input and returns the inputs unchanged.

  Tensors are logged with their shape and dtype; anything else is logged by
  value.

  Args:
    theta: Unused.
    *args: Inputs to log.

  Returns:
    The single arg itself when exactly one was given, otherwise the args
    tuple.
  """
  p = self.params
  with tf.name_scope(p.name) as name_scope:
    for i, arg in enumerate(args):
      if isinstance(arg, tf.Tensor):
        tf.logging.info(
            'FProp inputs in {}: arg_{} shape = {} dtype = {}'.format(
                name_scope, i, arg.shape, arg.dtype.name))
      else:
        tf.logging.info(
            'FProp non-Tensor input in {}: arg_{} arg = {}'.format(
                name_scope, i, arg))
  return args[0] if len(args) == 1 else args
def TraverseLayer(layer, fn):
  """Traverses the layer tree and invokes fn(layer) on each node.

  Lists and tuples of layers are traversed element by element. For a single
  layer, `fn` is called on the layer first and then its children are visited
  in sorted (alphabetical) order, all inside the layer's name scope.

  Args:
    layer: a BaseLayer, or a list/tuple of them.
    fn: a function called with the layer as its only argument; its return
      value is ignored.
  """
  if isinstance(layer, (list, tuple)):
    for sub_layer in layer:
      TraverseLayer(sub_layer, fn)
    return

  with tf.name_scope(layer.params.name):
    fn(layer)
    # Traverse all children in alphabetical order.
    for _, child_layer in sorted(layer.children.items()):
      TraverseLayer(child_layer, fn)
def ZeroState(self, theta, prepared_inputs, batch_size):
  """Creates a zero state NestedMap for this step.

  Each sub-step's ZeroState is called with its own slice of `theta` and
  `prepared_inputs`.

  Args:
    theta: variables used by sub-steps.
    prepared_inputs: Output from a call to PrepareExternalInputs.
    batch_size: The number of items in the batch that FProp will process.

  Returns:
    A NestedMap of ZeroState results for each sub-step, keyed by sub-step
    name.
  """
  zero_states = py_utils.NestedMap()
  with tf.name_scope(self.params.name):
    for seq in self._seq:
      sub_name = seq.name
      zero_states[sub_name] = seq.step.ZeroState(
          theta[sub_name], prepared_inputs[sub_name], batch_size)
  return zero_states
def FProp(self, theta, inputs, *args):
  """Runs the body with theta averaged according to the expert distribution.

  Args:
    theta: Weights of this layer; `theta.body` holds the body's values.
    inputs: First input, also used to compute the expert distribution.
    *args: Remaining inputs, forwarded unchanged.

  Returns:
    The output of `self.body.FProp` called with the weighted theta.
  """
  p = self.params
  with tf.name_scope(p.name) as scope:
    expert_dist = self._GetExpertDist(theta, inputs, *args)
    if not self.do_eval:
      summary_utils.histogram('soft_cond_{}'.format(scope), expert_dist)

    # Only keys of variables created by the body are averaged; this excludes
    # non-variable extra_theta like global_step.
    body_var_keys = {key for key, _ in self.body.vars.FlattenItems()}

    def _MaybeAverage(key, value):
      """Returns the weighted average over dim 0 for body variables."""
      if key in body_var_keys and value is not None:
        return tf.einsum('i,i...->...', expert_dist, value)
      return value

    mixed_values = [
        _MaybeAverage(key, value)
        for key, value in theta.body.FlattenItems()
    ]
    weighted_theta = theta.body.Pack(mixed_values)
    return self.body.FProp(weighted_theta, inputs, *args)
def FProp(self, theta, prepared_inputs, step_inputs, padding, state0):
  """A single inference step for this step graph.

  Sub-steps run in the order of `self._seq`. Each one reads its inputs from a
  shared GraphTensors container (seeded with 'prepared_inputs' and
  'step_inputs') and stores its output there under the first name of its
  output signature, so later sub-steps can consume it.

  Args:
    theta: variables used by sub-steps.
    prepared_inputs: A NestedMap containing external_inputs that were
      pre-processed by the PrepareExternalInputs method of each sub-step. The
      keys are the names of the sub-steps.
    step_inputs: A NestedMap of [batch, ...] tensors. The structure of this
      depends on the graph implementation.
    padding: A 0/1 float tensor of shape [batch_size]; 1.0 means that this
      batch element is empty in this step.
    state0: A NestedMap of state variables produced by either ZeroState or a
      previous invocation of this FProp step. The keys are the names of the
      sub-steps.

  Returns:
    (output, state1), both of which are NestedMaps.
    output is implementation-dependent and is defined by the output_signature
    parameter.
    state1 is a NestedMap where the keys are names of sub-steps and the values
    are state outputs from their FProp methods.
  """
  p = self.params
  tensors = builder_layers.GraphTensors()
  tensors.StoreTensor('prepared_inputs', prepared_inputs)
  tensors.StoreTensor('step_inputs', step_inputs)
  state1 = py_utils.NestedMap()

  with tf.name_scope(p.name):
    for seq in self._seq:
      tf.logging.vlog(1, 'GraphStep: call %s', seq.name)
      # Only sub-steps with an external signature receive prepared inputs.
      external = prepared_inputs[seq.name] if seq.external_signature else None
      signature_map = py_utils.NestedMap(inputs=seq.signature.inputs)
      step_args = signature_map.Transform(tensors.GetTensor).inputs[0]
      out, sub_state1 = seq.step.FProp(theta[seq.name], external, step_args,
                                       padding, state0[seq.name])
      tensors.StoreTensor(seq.signature.outputs[0], out)
      state1[seq.name] = sub_state1

  # Resolve the graph-level outputs from the tensors stored above.
  output_map = py_utils.NestedMap(inputs=self.output_signature.inputs)
  output_tensors = output_map.Transform(tensors.GetTensor).inputs[0]
  return output_tensors, state1
def FProp(self, theta, inputs, paddings=None):
  """Apply batch normalization.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].
    paddings: The paddings tensor. Shaped [..., 1], with the same rank as the
      input tensor. When None, `self._GetDefaultPaddings(inputs)` is used.

  Returns:
    Output after applying batch normalization, with the same shape as
    'inputs'.
  """
  p = self.params
  if paddings is None:
    paddings = self._GetDefaultPaddings(inputs)

  with tf.name_scope(p.name):
    norm_mean, norm_variance, beta, gamma = self.ComputeAndUpdateMoments(
        theta, inputs, paddings)

    # The variance must be non-negative and both moments must have one entry
    # per channel (the last input dimension).
    moment_checks = [
        py_utils.assert_greater_equal(norm_variance,
                                      tf.zeros_like(norm_variance)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_mean)),
        py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                    tf.shape(norm_variance)),
    ]
    with tf.control_dependencies(moment_checks):
      use_fused = p.use_fused_batch_norm_for_eval and self.do_eval
      if use_fused:
        bn_output, _, _ = nn.fused_batch_norm(
            inputs,
            gamma,
            beta,
            norm_mean,
            norm_variance,
            self._epsilon,
            is_training=False)
      else:
        bn_output = tf.nn.batch_normalization(inputs, norm_mean,
                                              norm_variance, beta, gamma,
                                              self._epsilon)

      if p.set_padded_output_to_zero:
        bn_output *= 1.0 - paddings

    return bn_output
def FProp(self, theta, source_vecs, source_paddings, target_vecs,
          target_paddings, source_segment_id, target_segment_id,
          transparent_acc, transparent_acc_helper, source_task_id=None,
          target_task_id=None):
  """Delegates to the shared GPipe transformer decoder FProp.

  All arguments are forwarded unchanged, together with this layer and the
  GPipeEvolvedTransformerDecoderLayer class.

  Returns:
    Whatever `_common_gpipe_transformer_decoder_fprop` returns.
  """
  layer_name = self.params.name
  with tf.name_scope(layer_name):
    outputs = _common_gpipe_transformer_decoder_fprop(
        self, GPipeEvolvedTransformerDecoderLayer, theta, source_vecs,
        source_paddings, target_vecs, target_paddings, source_segment_id,
        target_segment_id, transparent_acc, transparent_acc_helper,
        source_task_id, target_task_id)
  return outputs
def FProp(self, theta, current_step):
  """Updates and returns the decay factor based on the best/last step.

  The factor is multiplied by `p.decay` (but never drops below
  `p.min_factor`) when the gap between the last step and the reference step
  reaches `p.window`. The reference step is moved to the last step whenever
  the factor changed.

  Args:
    theta: Unused.
    current_step: Unused.

  Returns:
    The result of assigning the new factor to `self._cur_factor`, which runs
    after the reference step has been updated.
  """
  p = self.params
  with tf.name_scope(p.name):
    best_and_last = self._best_step
    best_step = best_and_last[0]
    last_step = best_and_last[1]

    # The reference step never lags behind the best step.
    ref_step = tf.maximum(self._ref_step, best_step)
    cur_factor = self._cur_factor

    # Decay if no improvement within window.
    within_window = last_step - ref_step < p.window
    decayed_factor = tf.maximum(p.min_factor, cur_factor * p.decay)
    new_factor = tf.where(within_window, cur_factor, decayed_factor)

    # Update ref_step if we decayed.
    factor_unchanged = tf.equal(new_factor, cur_factor)
    new_ref_step = tf.where(factor_unchanged, ref_step, last_step)
    update_ref_step = tf.assign(self._ref_step, new_ref_step)
    with tf.control_dependencies([update_ref_step]):
      return tf.assign(self._cur_factor, new_factor)