def FProp(self, theta, inputs):
  """Applies the `conv` child layer to `inputs.vec`.

  Args:
    theta: Weights container; `theta.conv` is forwarded to the conv child.
    inputs: An object exposing `.vec` and `.paddings`.

  Returns:
    A `py_utils.NestedMap` with `vec` set to the conv output and `paddings`
    set to the untouched `inputs.paddings`.
  """
  with tf.name_scope(self.params.name):
    conv_out = self.conv.FProp(theta.conv, inputs.vec)
  return py_utils.NestedMap(vec=conv_out, paddings=inputs.paddings)
def FProp(self, theta, external_inputs, step_inputs, padding, state0):
  """Joins the step inputs with ':' and threads `state0` through.

  `theta` and `padding` are accepted but not read.

  Args:
    theta: Unused.
    external_inputs: A string appended after the step inputs before joining.
    step_inputs: An object whose `.inputs` is a list of strings.
    padding: Unused.
    state0: A string; appended to the output and prefixed to the new state.

  Returns:
    A tuple `(output_map, new_state)`: `output_map.output` is the ':'-joined
    step inputs plus `external_inputs`, followed by `state0`; `new_state` is
    `state0` followed by the ':'-joined step inputs.
  """
  step_tokens = step_inputs.inputs
  joined_with_external = ':'.join(step_tokens + [external_inputs])
  step_output = py_utils.NestedMap(output=joined_with_external + state0)
  next_state = state0 + ':'.join(step_tokens)
  return step_output, next_state
def GenericInput(processor, **kwargs):
  """Creates a batched input pipeline around a per-record `processor`.

  `processor` is wrapped in a `tf.Defun` taking `(source_id, record)` and is
  handed to `ops.gen_x_ops.generic_input` together with `**kwargs`.

  Example usage::

    def ParseRecord(record):
      # Map one tf.string record to an (example, bucketing key) pair.
      feature_map = ...
      features = tf.parse_single_example(record, feature_map)
      # An example is a NestedMap of tensors with no batch dimension.
      example = py_utils.NestedMap(field1=..., field2=...)
      # A scalar convertible to tf.int32; use 1 when every example has the
      # same size.
      bucketing_key = 1
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)
    # Dim 0 of every tensor in input_batch is the batch dimension.
    input_batch.field1 = ...

  The processor may instead take two arguments, which must be named exactly
  'source_id' and 'record'::

    def ParseRecord(source_id, record):
      # source_id is a tf.int32, record is a tf.string.
      example = py_utils.NestedMap(source_id=source_id, ...)
      ...
      return example, bucketing_key

    input_batch, bucket_keys = GenericInput(ParseRecord, file_pattern=..., ...)

  Args:
    processor: A callable taking either a single tf.string record or the
      keyword pair (source_id: tf.int32, record: tf.string). It returns
      `(output, bucketing_key)`, where `output` is a non-empty NestedMap or a
      non-empty list of tensors describing one example, and `bucketing_key` is
      a scalar convertible to tf.int32 (e.g. the sequence length). A negative
      `bucketing_key` causes the record to be dropped.
    **kwargs: Extra keyword arguments forwarded to x_ops.generic_input.

  Returns:
    A tuple `(outputs, bucket_keys)`:

    - outputs: Same structure as the `processor` output (NestedMap or list),
      with a leading batch dimension added to every tensor.
    - bucket_keys: A tf.int32 vector.

  Raises:
    ValueError: If `processor` takes neither exactly one argument nor exactly
      the two arguments 'source_id' and 'record' ('self' is ignored).
  """
  # Filled in while the Defun body is traced; reused afterwards to re-pack the
  # flat op outputs into the processor's original structure.
  template = py_utils.NestedMap()

  def _FlatOutputProcessor(source_id, record):
    """Calls `processor` and returns its flattened tensors plus the key."""
    argspec = tf_inspect.getargspec(processor)
    tf.logging.debug('GenericInput.processor.argspec=%s', argspec)
    arg_names = set(argspec.args) - {'self'}
    if len(arg_names) == 1:
      example, key = processor(record)
    elif arg_names == {'source_id', 'record'}:
      example, key = processor(source_id=source_id, record=record)
    else:
      raise ValueError(
          'GenericInput: processor should take either a single arg '
          'or two args named as "source_id" and "record". '
          'Actual: %s' % arg_names)

    # The example must be a non-empty list or NestedMap made only of tensors.
    if not isinstance(example, list):
      assert isinstance(example, py_utils.NestedMap), '{}'.format(example)
      assert example
      assert all(isinstance(t, tf.Tensor)
                 for t in example.Flatten()), '{}'.format(example.DebugString())
    else:
      assert example
      assert all(isinstance(t, tf.Tensor) for t in example), '{}'.format(example)

    key = tf.cast(key, tf.int32)
    tf.logging.debug('Processor outputs=%s bucketing_key=%s', example, key)
    template.out_values = example
    flat_example = template.Flatten()
    tf.logging.debug('Processor flat outputs=%s', flat_example)
    tf.logging.debug('extra_inputs=%s extra_args=%s extra_vars=%s',
                     function.get_extra_inputs(), function.get_extra_args(),
                     function.get_extra_vars())
    # The processor must not capture anything that becomes an extra argument.
    assert not function.get_extra_args(), (
        'fns {} is not pure: extra_args={}'.format(processor,
                                                   function.get_extra_args()))
    return flat_example + [key]

  defun = tf.Defun(tf.int32, tf.string)(_FlatOutputProcessor)

  # The last output of the Defun is the bucketing key; the rest are the
  # flattened example tensors.
  sig_dtypes = [
      tf.DType(arg.type) for arg in defun.definition.signature.output_arg
  ]
  key_dtype = sig_dtypes[-1]
  assert key_dtype == tf.int32, ('%s is not expected.' % key_dtype)

  flat_outputs, bucket_keys = ops.gen_x_ops.generic_input(
      processor=defun, out_types=sig_dtypes[:-1], **kwargs)
  tf.logging.debug('x_ops.generic_input flat_outputs=%s', flat_outputs)

  # Restore the structure recorded in the template.
  outputs = template.Pack(flat_outputs).out_values
  tf.logging.debug('x_ops.generic_input outputs=%s', outputs)
  return outputs, bucket_keys
def Decode(self, input_batch):
  """Decode an input batch, computing predicted bboxes from residuals.

  Runs the forward pass and the loss, converts classification logits to
  sigmoid scores, applies `detection_decoder.DecodeWithNMS` under a CPU device
  scope, and assembles the decoder outputs.

  Args:
    input_batch: The input batch. It is passed unchanged to
      `ComputePredictions`, `_BBoxesAndLogits`, `ComputeLoss` and
      `output_decoder.ProcessOutputs`.

  Returns:
    A `py_utils.NestedMap` with:

    - per_class_predicted_bboxes: the boxes returned by NMS.
    - per_class_predicted_bbox_scores: the NMS scores multiplied by the valid
      mask; checked to have shape [batch, p.num_classes, p.max_nms_boxes].
    - per_class_valid_mask: the mask returned by NMS.
    - visualization_weights: the masked scores, zeroed wherever they are below
      `p.visualization_classification_threshold`.
    - global_step: the value of `py_utils.GetGlobalStep()`.
    - whatever additional entries `output_decoder.ProcessOutputs` returns.
  """
  p = self.params

  predictions = self.ComputePredictions(self.theta, input_batch)
  bboxes_and_logits = self._BBoxesAndLogits(input_batch, predictions)
  predicted_bboxes = bboxes_and_logits.predicted_bboxes
  # The first two dims of predicted_bboxes serve as batch size and box count
  # for the shape check on the logits below.
  batch_size, num_bboxes, _ = py_utils.GetShape(predicted_bboxes, 3)
  classification_logits = bboxes_and_logits.classification_logits
  classification_logits = py_utils.HasShape(
      classification_logits, [batch_size, num_bboxes, p.num_classes])

  classification_scores = tf.sigmoid(classification_logits)

  # Only the per-example dict of the loss is needed; when it carries a
  # 'score_scaler' entry the scores are rescaled by it.
  _, per_example_dict = self.ComputeLoss(self.theta, predictions, input_batch)
  if 'score_scaler' in per_example_dict:
    classification_scores *= per_example_dict['score_scaler']

  with tf.device('/cpu:0'):
    # Decode the predicted bboxes, performing NMS.
    _, per_cls_bboxes, per_cls_bbox_scores, per_cls_valid_mask = (
        detection_decoder.DecodeWithNMS(
            predicted_bboxes,
            classification_scores,
            nms_iou_threshold=p.nms_iou_threshold,
            score_threshold=p.nms_score_threshold,
            max_boxes_per_class=p.max_nms_boxes,
            use_oriented_per_class_nms=p.use_oriented_per_class_nms))

    # per_cls_valid_mask is [batch, num_classes, num_boxes] Tensor that
    # indicates which boxes were selected by NMS. Each example will have a
    # different number of chosen bboxes, so the mask is present to allow us
    # to keep the boxes as a batched dense Tensor.
    #
    # We mask the scores by the per_cls_valid_mask so that none of these boxes
    # will be interpreted as valid.
    per_cls_bbox_scores *= per_cls_valid_mask
    visualization_weights = py_utils.HasShape(
        per_cls_bbox_scores, [batch_size, p.num_classes, p.max_nms_boxes])

    # For top down visualization, filter boxes whose scores are not above the
    # visualization threshold.
    visualization_weights = tf.where(
        tf.greater_equal(visualization_weights,
                         p.visualization_classification_threshold),
        visualization_weights, tf.zeros_like(visualization_weights))

  # model_outputs is the subset of results handed to the output decoder.
  model_outputs = py_utils.NestedMap()
  model_outputs.per_class_predicted_bboxes = per_cls_bboxes
  model_outputs.per_class_predicted_bbox_scores = per_cls_bbox_scores
  model_outputs.per_class_valid_mask = per_cls_valid_mask

  decoder_outputs = py_utils.NestedMap({
      'per_class_predicted_bboxes': per_cls_bboxes,
      'per_class_predicted_bbox_scores': per_cls_bbox_scores,
      'per_class_valid_mask': per_cls_valid_mask,
      'visualization_weights': visualization_weights,
  })

  # Entries from the output decoder are merged in (and win on key clashes,
  # since this is a dict-style update).
  decoder_outputs.update(
      self.output_decoder.ProcessOutputs(input_batch, model_outputs))

  # Produce global step as an output (which is the step
  # of the checkpoint being decoded.)
  decoder_outputs.global_step = py_utils.GetGlobalStep()
  return decoder_outputs
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` with fields:

      - ids: The inputs tensor. It is expected to be of shape [batch, time].
      - paddings: The paddings tensor. Expected shape [batch, time].
      - segment_ids, segment_pos: Read only when `p.packed_input` is set.
      - task_ids: Read only when `p.task_emb` is set.

  Returns:
    A NestedMap containing:

    - encoded: The encoded features, either a tensor of shape
      [time, batch, depth], or a list of tensors if is_transparent is set in
      transformer_stack.
    - padding: of shape [time, batch]
    - segment_id: [time, batch] if packed inputs are supported by the model
      (and all layers), or None otherwise.
    - embedded_inputs: [time, batch, depth] embedded inputs tokens without
      positional encodings.
  """
  p = self.params
  with tf.name_scope(p.name):
    src_segment_id = None
    src_segment_pos = None
    # Attach runtime checks: ids and paddings have matching shapes and ids is
    # rank 2.
    input_ids = py_utils.with_dependencies([
        py_utils.assert_shape_match(
            tf.shape(input_batch.ids), tf.shape(input_batch.paddings)),
        py_utils.assert_equal(tf.rank(input_batch.ids), 2)
    ], input_batch.ids)

    if (not py_utils.use_tpu() and
        tf.flags.FLAGS.transformer_encoder_truncates_inputs):
      # Longest count of non-padded positions over the batch.
      max_seq_length = tf.cast(
          tf.reduce_max(tf.reduce_sum(1.0 - input_batch.paddings, 1)),
          tf.int32)
      # Everything past max_seq_length must be padding (> 0.5) before it is
      # safe to slice it away.
      paddings = py_utils.with_dependencies([
          py_utils.assert_equal(
              tf.constant(True, tf.bool),
              tf.reduce_all(input_batch.paddings[:, max_seq_length:] > 0.5))
      ], input_batch.paddings)
      input_ids = input_ids[:, :max_seq_length]
      paddings = paddings[:, :max_seq_length]
      if p.packed_input:
        src_segment_id = input_batch.segment_ids[:, :max_seq_length]
        src_segment_pos = input_batch.segment_pos[:, :max_seq_length]
    else:
      paddings = input_batch.paddings
      if p.packed_input:
        src_segment_id = input_batch.segment_ids
        src_segment_pos = input_batch.segment_pos

    max_time = tf.shape(input_ids)[1]

    # Input token embeddings + positional embeddings
    input_embs = self.token_emb.EmbLookup(theta.token_emb,
                                          tf.reshape(input_ids, [-1]))
    # Restore the [batch, time, dim] layout after the flat lookup.
    input_embs = tf.reshape(input_embs,
                            [-1, max_time, p.token_emb.embedding_dim])
    # [time, batch, dim]
    orig_input_embs = tf.transpose(input_embs, [1, 0, 2])

    if p.packed_input:
      position_embs = self.position_emb.FPropWithPosition(
          theta.position_emb, src_segment_pos)
    else:
      position_embs = self.position_emb.FProp(theta.position_emb, max_time)
      # Add a leading axis of size 1 so the embeddings broadcast over batch.
      position_embs = tf.reshape(position_embs,
                                 [1, max_time, p.token_emb.embedding_dim])
    input_embs += position_embs
    if p.task_emb:
      input_embs += self.task_emb.EmbLookup(theta.task_emb,
                                            input_batch.task_ids)

    # Project only when the embedding width differs from the model width.
    if p.model_dim != p.token_emb.embedding_dim:
      input_embs = self.emb_proj.FProp(theta.emb_proj, input_embs)

    # Switch paddings (and segment ids) to time-major.
    paddings = tf.transpose(paddings)
    if p.packed_input:
      src_segment_id = tf.transpose(src_segment_id)
    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)

    # [time, batch, dim]
    transformer_input = tf.transpose(input_embs, [1, 0, 2])

  encoded, padding, segment_id = self.transformer_stack.FProp(
      theta.transformer_stack, transformer_input, paddings, src_segment_id)
  return py_utils.NestedMap(
      encoded=encoded,
      padding=padding,
      segment_id=segment_id,
      embedded_inputs=orig_input_embs)
def BuildTpuSubgraph(self):
  """Builds the TPU computation that trains for a loop and then decodes.

  A training loop of `self._train_steps_per_loop` steps is built, and the
  decode function is made to depend on it through a control dependency. The
  combined function is passed to `tpu.split_compile_and_shard`.

  Side effects: sets `self._eval_metrics`, `self._train_model`, `self._model`,
  `self._decode_model`, `self._decode_model_task`, `self.metrics_nm`,
  `self._compile_op` and `self.metrics`.

  Returns:
    None.
  """
  if self._ml_perf_log:
    # Emit the MLPerf hyperparameter log entries.
    mlp_log.mlperf_print('global_batch_size', self._ml_perf.global_batch_size)
    mlp_log.mlperf_print('max_sequence_length',
                         self._ml_perf.max_sequence_length)
    mlp_log.mlperf_print('opt_name', self._ml_perf.optimizer_name)
    mlp_log.mlperf_print('opt_base_learning_rate',
                         self._ml_perf.base_learning_rate)
    mlp_log.mlperf_print('opt_learning_rate_warmup_steps',
                         self._ml_perf.warmup_steps)

  with py_utils.OpportunisticVariableReuseScope(True):
    self._eval_metrics = metrics.TpuEvalMetrics()
    data_parallelism = self.data_parallelism

    def TpuTrainStep():
      """Train a shard of a batch on a single TPU core.

      Do not calculate loss metrics.

      Returns:
       [train_op].
      """
      self._train_model = self._train_task_params.Instantiate()
      self._model = self._train_model
      self._train_model.ConstructFPropBPropGraph()
      return [self._train_model.GetTask().train_op]

    @tpu_function.on_device_training_loop
    def TpuTrain():
      # Repeat the single training step self._train_steps_per_loop times.
      loop_result = tpu_training_loop.repeat(
          self._train_steps_per_loop,
          TpuTrainStep,
          inputs=[],
          name='train_loop')
      return loop_result

  py_utils.ResetStepSeed()

  def _DecodeFn():
    """Decode call to be compiled for TPU."""
    with py_utils.OpportunisticVariableReuseScope(True):
      with cluster_factory.SetEval(True):
        self._decode_model = self._decode_task_params.Instantiate()
        self._decode_model_task = self._decode_model.GetTask()
        if py_utils.use_tpu():
          input_batch = self._decode_model_task.input_generator.CreateTpuFeeds()
        else:
          input_batch = self._decode_model_task.input_generator.SplitInputBatch(
              self.cluster.num_splits_per_client)
        metrics_dict = self._decode_model_task.Decode(input_batch)
        # Keep the structured map so the flat sharded results can be re-packed
        # after compilation.
        self.metrics_nm = py_utils.NestedMap(metrics_dict)
        return self.metrics_nm.Flatten()

  def TrainAndDecode():
    # Decode ops are created under a control dependency on the training loop.
    with tf.control_dependencies([TpuTrain()]):
      return _DecodeFn()

  self._compile_op, batch_parallel_res = tpu.split_compile_and_shard(
      TrainAndDecode,
      num_shards=data_parallelism,
      device_assignment=py_utils.GetTpuDeviceAssignment())

  # Re-pack the flat results into the structure recorded by _DecodeFn.
  self.metrics = py_utils.NestedMap(self.metrics_nm)
  self.metrics = self.metrics.Pack(batch_parallel_res)
  return None
def _PreBeamSearchStepCallback(self,
                               theta,
                               source_encs,
                               source_paddings,
                               step_ids,
                               states,
                               num_hyps_per_beam,
                               additional_source_info=None):
  """Runs one decoder step and returns beam search results and new states.

  Args:
    theta: Weights container; `theta.emb` and `theta.softmax` are read here
      and `theta` itself is passed to `ApplyClipping` and `_DecodeStep`.
    source_encs: A tensor of shape [src_len, src_batch, source_dim]. Not read
      directly by this method.
    source_paddings: A tensor of shape [src_len, src_batch]. Not read directly
      by this method.
    step_ids: A tensor of shape [tgt_batch, 1].
    states: A `.NestedMap` holding 'rnn_states', 'atten_context',
      'atten_probs' and 'atten_states' from the previous step.
    num_hyps_per_beam: Beam size. Not read directly by this method.
    additional_source_info: a `.NestedMap` of extra source context; discarded.

  Returns:
    A tuple (results, out_states).

    results: A `.NestedMap` of beam search results.
      atten_probs: The attention probs exposed to beam search: the previous
        step's probs when `p.use_prev_atten_ctx` is set, otherwise the
        updated ones.
      log_probs: Output of `qlogsoftmax` over the step logits.
    out_states: A `.NestedMap`. The updated states.
      rnn_states: Last state of the RNN.
      atten_context: Updated attention context vector.
      atten_probs: Updated attention probs, reshaped to the shape of the
        previous step's probs.
      atten_states: Updated attention states.
  """
  p = self.params
  # The extra source context is not used by this decoder.
  del additional_source_info

  last_rnn_states = states['rnn_states']
  last_atten_context = states['atten_context']
  last_atten_probs = states['atten_probs']
  last_atten_states = states['atten_states']

  # All-zero paddings shaped like step_ids.
  zero_paddings = tf.zeros(py_utils.GetShape(step_ids), dtype=p.dtype)
  flat_ids = tf.reshape(step_ids, [-1])
  step_embs = self.emb.EmbLookup(theta.emb, flat_ids)
  step_embs = self.ApplyClipping(theta, step_embs)

  decode_step_out = self._DecodeStep(theta, step_embs, zero_paddings,
                                     last_atten_context, last_rnn_states,
                                     last_atten_states)
  (atten_context, raw_atten_probs, rnn_states, step_out,
   atten_states) = decode_step_out
  atten_probs = tf.reshape(raw_atten_probs, tf.shape(last_atten_probs))

  logits = self.softmax.Logits(theta.softmax, [step_out])
  log_probs = self.fns.qlogsoftmax(
      logits, qmin=p.qlogsoftmax_range_min, qmax=0.0)

  # Choose which attention probs beam search gets to see.
  exposed_atten_probs = (
      last_atten_probs if p.use_prev_atten_ctx else atten_probs)

  bs_results = py_utils.NestedMap(
      atten_probs=exposed_atten_probs, log_probs=log_probs)
  new_states = py_utils.NestedMap(
      rnn_states=rnn_states,
      atten_context=atten_context,
      atten_probs=atten_probs,
      atten_states=atten_states)
  return bs_results, new_states
def FPropFullSequence(self, theta, ids, paddings):
  """Runs `FProp` on `ids`/`paddings` and returns only the 'encoded' entry.

  Args:
    theta: Weights container forwarded to `FProp`.
    ids: Value stored as the `ids` field of the input batch.
    paddings: Value stored as the `paddings` field of the input batch.

  Returns:
    The 'encoded' entry of the `FProp` result.
  """
  batch = py_utils.NestedMap(ids=ids, paddings=paddings)
  fprop_out = self.FProp(theta, batch)
  return fprop_out['encoded']
def FProp(self, theta, input_batch):
  """Embeds source ids and transforms with TransformerStack.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    input_batch: A `.NestedMap` object containing:
      ids - The inputs tensor of shape [batch, time].
      paddings - The ids' paddings of shape [batch, time].
      segment_ids, segment_pos - Read only when p.packed_input is set.

  Returns:
    A '.NestedMap' object containing:
      encoded - The encoded features of shape [time, batch, dim] or [batch,
        time, dim], depending p.output_data_format.
      padding - The encoded features' padding of shape [time, batch] or
        [batch, time].
      seq_lengths - int32 per-example count of non-padded positions, computed
        as the sum of (1 - padding) over axis 1 before any transpose.
      segment_id - The segmentation of packed inputs of shape [time, batch] or
        [batch, time] if it is supported by the model, or None otherwise.
      embedded_inputs - The embedded inputs tokens without positional
        encodings of shape [time, batch, dim] or [batch, time, dim].
  """

  p = self.params
  with tf.name_scope(p.name):
    # [batch, time]
    input_ids = input_batch.ids
    # [batch, time]
    paddings = input_batch.paddings

    # [batch, time]
    segment_ids = input_batch.segment_ids if p.packed_input else None

    batch = py_utils.GetShape(input_ids)[0]
    time = py_utils.GetShape(input_ids)[1]

    # Embedding layer.
    # [batch, time, dim]
    # With shared embeddings the softmax layer performs the lookup.
    if not p.shared_emb:
      input_embs = self.token_emb.EmbLookup(theta.token_emb, input_ids)
    else:
      input_embs = self.softmax.EmbLookup(theta.softmax, input_ids)
    # Kept so the embeddings without positional encodings can be returned.
    orig_input_embs = input_embs

    # [1, time, dim]
    # NOTE(review): in the packed branch the leading axis is added to the
    # output of FPropWithPosition on [batch, time] positions; the shape is
    # presumably [1, batch, time, dim] there -- TODO confirm. The reshape to
    # [batch, time, p.model_dim] below fixes the final shape either way.
    if p.packed_input:
      positions = input_batch.segment_pos
      position_embs = tf.expand_dims(
          self.position_emb.FPropWithPosition(theta.position_emb, positions),
          0)
    else:
      position_embs = tf.expand_dims(
          self.position_emb.FProp(theta.position_emb, time), 0)

    # [batch, time, dim]
    input_embs += position_embs

    # Cast activations and paddings to the dropout layer's fprop dtype when
    # one is configured.
    if p.input_dropout_tpl.fprop_dtype:
      input_embs = tf.cast(input_embs, p.input_dropout_tpl.fprop_dtype)
      paddings = tf.cast(paddings, p.input_dropout_tpl.fprop_dtype)

    input_embs = self.input_dropout.FProp(theta.input_dropout, input_embs)
    # [batch, time, dim]
    transformer_input = input_embs
    # Explicitly set the input shape of Transformer layers, to avoid
    # unknown shape error occurred to tf.einsum on nonTPU devices.
    transformer_input = tf.reshape(transformer_input,
                                   [batch, time, p.model_dim])

    # Compute self-attention segment mask once.
    # NOTE(review): the non-packed mask uses tf.zeros' default dtype while the
    # packed branch uses transformer_input.dtype -- confirm this is intended
    # when fprop_dtype is set.
    if p.packed_input:
      segment_mask = batch_major_attention.SegmentMask(
          segment_ids, segment_ids, dtype=transformer_input.dtype)
    else:
      segment_mask = tf.zeros([batch, 1, time, time])

    shape = py_utils.GetShape(transformer_input)
    batch_size = shape[0]
    seq_len = shape[1]
    paddings = tf.reshape(paddings, [batch_size, seq_len])
    encoded, padding = self.transformer_stack.FProp(theta.transformer_stack,
                                                    transformer_input,
                                                    paddings, segment_mask)

    if p.final_layer_norm:
      encoded = self.final_ln.FProp(theta.final_ln, encoded)

    # Computed while padding is still [batch, time], so axis 1 is time.
    seq_lengths = tf.cast(tf.reduce_sum(1. - padding, axis=1), tf.int32)

    if p.output_data_format == 'TBC':
      encoded = tf.transpose(encoded, [1, 0, 2])  # [time, batch, dim]
      padding = tf.transpose(padding)  # [time, batch]
      segment_ids = tf.transpose(segment_ids) if p.packed_input else None
      orig_input_embs = tf.transpose(orig_input_embs, [1, 0, 2])

    return py_utils.NestedMap(
        encoded=encoded,
        padding=padding,
        seq_lengths=seq_lengths,  # used by beam_search_helper.
        segment_id=segment_ids,
        embedded_inputs=orig_input_embs)
def MulSumFnMeta(x):
  """Returns metadata with `flops=2` and `x` as the single output shape."""
  single_output = (x,)
  return py_utils.NestedMap(flops=2, out_shapes=single_output)
def AddFnMeta(x, y):
  """Returns metadata with `flops=2` and `x` as the single output shape.

  `y` is accepted but discarded.
  """
  del y  # Only the first argument contributes to the output shapes.
  single_output = (x,)
  return py_utils.NestedMap(flops=2, out_shapes=single_output)
def _FnMeta(*shapes):
  """Returns metadata with `flops=1` and the given shapes as output shapes."""
  meta = py_utils.NestedMap(flops=1, out_shapes=shapes)
  return meta
def testGraphTensors(self):
  """A value stored under 't' is retrievable through a dotted path."""
  store = layers.GraphTensors()
  nested = py_utils.NestedMap(a=py_utils.NestedMap(b='c'))
  store.StoreTensor('t', nested)
  fetched = store.GetTensor('t.a.b')
  self.assertEqual('c', fetched)
def FPropMeta(cls, p, inputs):
  """Checks the non-None input shapes and reports them as the output shapes.

  Args:
    p: Layer params (not read here).
    inputs: A structure supporting `Filter` and `Flatten`; entries that are
      None are skipped by the shape check.

  Returns:
    A `py_utils.NestedMap` with `flops=1` and `out_shapes=(inputs,)`.
  """
  present = inputs.Filter(lambda shape: shape is not None)
  py_utils.CheckShapes(tuple(present.Flatten()))
  out_shapes = (inputs,)
  return py_utils.NestedMap(flops=1, out_shapes=out_shapes)
def ReverseAndGrad(self, theta, outputs, d_outputs, f_seed, g_seed,
                   *extra_inputs):
  """Implements Algorithm 1 in the revnet paper.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    outputs: A NestedMap: .split1 and .split2 corresponding to y1 and y2.
    d_outputs: A NestedMap: .split1 and .split2 corresponding to dy1 and dy2,
      the total derivatives.
    f_seed: Scalar tensor. The step seed used in forward for the f block.
    g_seed: Scalar tensor. The step seed used in forward for the g block. The
      step seeds are needed for deterministic randomness, e.g. to ensure
      dropout generate the same random mask in forward and reverse_grad.
    *extra_inputs: additional inputs that will be passed to both f and g. No
      gradient will be computed for these inputs.

  Returns:
    A tuple of NestedMaps

    - inputs: .split1 and .split2 corresponding to x1 and x2.
    - d_inputs: .split1 and .split2 corresponding to dx1 and dx2, the total
      derivatives with respect to inputs.
    - d_theta: has the same structure as theta. The total derivatives with
      respect to weights. Its `global_step` entry is all zeros.
  """

  # Stop gradient on the outputs to avoid circular symbolic dependency.
  y1 = tf.stop_gradient(outputs.split1)
  y2 = tf.stop_gradient(outputs.split2)
  dy1 = d_outputs.split1
  dy2 = d_outputs.split2

  # Computes the reverse.
  # Each block is re-run after resetting the step seed to the value used in
  # the forward pass, so random ops reproduce the forward behavior.
  z1 = y1
  py_utils.ResetStepSeed(g_seed)
  gz1 = self.g_block.FProp(theta.g_block, z1, *extra_inputs)
  x2 = y2 - gz1
  py_utils.ResetStepSeed(f_seed)
  fx2 = self.f_block.FProp(theta.f_block, x2, *extra_inputs)
  x1 = z1 - fx2

  # Computes the gradients.
  # dz1 and dx2 add the incoming derivative to the derivative propagated
  # back through the g and f blocks respectively.
  dz1 = dy1 + tf.gradients(gz1, z1, dy2)[0]
  dx2 = dy2 + tf.gradients(fx2, x2, dz1)[0]

  # Weight gradients; weights not connected to the block output get zeros
  # instead of None so the result can be packed back into theta's structure.
  dgw = tf.gradients(
      gz1,
      theta.g_block.Flatten(),
      dy2,
      unconnected_gradients=tf.UnconnectedGradients.ZERO)
  dgw = theta.g_block.Pack(dgw)

  dfw = tf.gradients(
      fx2,
      theta.f_block.Flatten(),
      dz1,
      unconnected_gradients=tf.UnconnectedGradients.ZERO)
  dfw = theta.f_block.Pack(dfw)

  return (py_utils.NestedMap(split1=x1, split2=x2),
          py_utils.NestedMap(split1=dz1, split2=dx2),
          py_utils.NestedMap(
              f_block=dfw,
              g_block=dgw,
              global_step=tf.zeros_like(theta.global_step)))
def InitBeamSearchCallBack(unused_theta, unused_encoder_outputs,
                           unused_num_hyps_per_beam):
  """Returns zero-filled initial beam search results and empty states.

  All three arguments are ignored; `tgt_batch_size` and `vocab_size` come from
  the enclosing scope.

  Returns:
    A tuple `(initial_results, initial_states)`: `initial_results` holds
    zero `log_probs` of shape [tgt_batch_size, vocab_size] and zero
    `atten_probs` of shape [tgt_batch_size, 0]; `initial_states` is an empty
    `py_utils.NestedMap`.
  """
  initial_results = py_utils.NestedMap(
      log_probs=tf.zeros([tgt_batch_size, vocab_size]),
      atten_probs=tf.zeros([tgt_batch_size, 0]))
  initial_states = py_utils.NestedMap()
  return initial_results, initial_states
def _OutfeedEnqueue(self, per_example_tensors):
  """Enqueues the flattened per-example tensors on the TPU outfeed.

  Args:
    per_example_tensors: A mapping of tensors; may be empty or None.

  Returns:
    The op from `tpu_ops.outfeed_enqueue_tuple`, or a `tf.no_op()` when
    `per_example_tensors` is falsy.
  """
  if per_example_tensors:
    flat_tensors = py_utils.NestedMap(per_example_tensors).Flatten()
    return tpu_ops.outfeed_enqueue_tuple(flat_tensors)
  return tf.no_op()
def GetBeamSearchHelperResults(sess,
                               num_hyps_per_beam,
                               pass_seq_lengths=False,
                               force_eos_in_top_k=False):
  """Runs a BeamSearchHelper decode on fixed-seed random inputs.

  Args:
    sess: Session used to evaluate the decode outputs.
    num_hyps_per_beam: Number of hyps per beam; the target batch size is
      `2 * num_hyps_per_beam`.
    pass_seq_lengths: If True, `seq_lengths=[4, 3]` is added to the encoder
      outputs.
    force_eos_in_top_k: Forwarded to the BeamSearchHelper params.

  Returns:
    A tuple `(topk_ids, topk_lens, topk_scores)` of evaluated values.
  """
  # Fixed seeds keep the random inputs reproducible across runs.
  np.random.seed(9384758)
  tf.random.set_seed(8274758)
  vocab_size = 12
  src_len = 5
  tgt_len = 7
  src_batch_size = 2
  tgt_batch_size = src_batch_size * num_hyps_per_beam
  p = beam_search_helper.BeamSearchHelper.Params().Set(
      name='bsh', target_seq_len=tgt_len, force_eos_in_top_k=force_eos_in_top_k)
  bs_helper = p.Instantiate()

  def InitBeamSearchState(unused_theta, unused_encoder_outputs,
                          unused_num_hyps_per_beam):
    # Random attention probs shared by the initial results and states.
    atten_probs = tf.constant(
        np.random.normal(size=(tgt_batch_size, src_len)), dtype=tf.float32)
    return (py_utils.NestedMap({
        'log_probs': tf.zeros([tgt_batch_size, vocab_size]),
        'atten_probs': atten_probs,
    }), py_utils.NestedMap({'atten_probs': atten_probs}))

  def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                unused_step_ids, states,
                                unused_num_hyps_per_beam):
    # States pass through unchanged; log_probs are seeded random values.
    atten_probs = tf.identity(states.atten_probs)
    logits = tf.random.normal([tgt_batch_size, vocab_size], seed=8273747)
    return (py_utils.NestedMap({
        'atten_probs': atten_probs,
        'log_probs': logits
    }), states)

  def PostBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                 unused_new_step_ids, states):
    return states

  src_enc = tf.random.normal([src_len, src_batch_size, 8], seed=982774838)
  # [src_len, src_batch] paddings: the first sequence has 4 unpadded steps and
  # the second has 3, matching the seq_lengths below.
  src_enc_padding = tf.constant(
      [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
      dtype=tf.float32)
  encoder_outputs = py_utils.NestedMap(encoded=src_enc, padding=src_enc_padding)
  if pass_seq_lengths:
    encoder_outputs['seq_lengths'] = tf.constant([4, 3], dtype=tf.int32)

  theta = py_utils.NestedMap()
  decoder_output = bs_helper.BeamSearchDecode(theta, encoder_outputs,
                                              num_hyps_per_beam,
                                              InitBeamSearchState,
                                              PreBeamSearchStepCallback,
                                              PostBeamSearchStepCallback)

  topk_ids, topk_lens, topk_scores = sess.run([
      decoder_output.topk_ids, decoder_output.topk_lens,
      decoder_output.topk_scores
  ])
  return topk_ids, topk_lens, topk_scores
def _PreBeamSearchStepCallback(self,
                               theta,
                               source_encs,
                               source_paddings,
                               step_ids,
                               states,
                               num_hyps_per_beam,
                               additional_source_info=None):
  """Returns logits for sampling ids and the next model states.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    source_encs: A tensor of shape [src_len, src_batch, source_dim].
      Can be [time, batch, depth, num_layers] if is_transparent is set.
    source_paddings: A tensor of shape [src_len, src_batch].
    step_ids: A tensor of shape [tgt_batch, 1].
    states: A `.NestedMap` of tensors representing states that the clients
      would like to keep track of for each of the active hyps. This method
      reads `states.time_step` and `states.prefix_states`.
    num_hyps_per_beam: Beam size. Not read by this method.
    additional_source_info: a `.NestedMap` of tensors containing extra context
      information about the source that may be useful for decoding.

  Returns:
    A tuple (results, out_states).

    results: A `.NestedMap` of beam search results.
      atten_probs: Dummy uniform attention probs (each entry is 1 / src_len),
        of shape [tgt_batch, src_len].
      log_probs: Log prob for each of the tokens in the target vocab. This is
        of shape [tgt_batch, vocab_size].
    out_states: A `.NestedMap` with the same structure as `states`, in which
      `prefix_states` is replaced by the value returned from `ExtendStep` and
      `time_step` is incremented by 1.
  """
  p = self.params
  # additional_source_info is currently not used.
  del additional_source_info

  target_time = states.time_step
  prefix_states = states.prefix_states

  # Re-pack the flattened states into a new NestedMap so the assignments
  # below do not modify `states` itself.
  new_states = states.Pack(states.Flatten())

  # The scalar at target_time[0][0] is used as the current step index.
  layer_out, updated_prefix_states = self.ExtendStep(
      theta, source_encs, source_paddings, tf.squeeze(step_ids, 1),
      target_time[0][0], prefix_states)

  new_states.prefix_states = updated_prefix_states
  new_states.time_step = target_time + 1

  softmax_input = tf.reshape(layer_out, [-1, p.softmax.input_dim])
  logits = self.softmax.Logits(theta.softmax, [softmax_input])

  num_hyps = py_utils.GetShape(step_ids)[0]
  source_len = py_utils.GetShape(source_encs)[0]
  # [time * batch, num_classes] -> [time, batch, num_classes]
  logits = tf.reshape(logits, (-1, num_hyps, p.softmax.num_classes))
  # [time, batch, num_classes] -> [batch, time, num_classes]
  logits = tf.transpose(logits, (1, 0, 2))

  # Dummy attention probs
  atten_probs = tf.ones([num_hyps, source_len]) / tf.to_float(source_len)

  # Only return logits for the last ids
  log_probs = tf.nn.log_softmax(tf.squeeze(logits, axis=1))

  bs_results = py_utils.NestedMap({
      'atten_probs': atten_probs,
      'log_probs': log_probs,
  })

  return bs_results, new_states
def testBeamSearchForceLastChunkEocInTopK(self, is_last_chunk,
                                          force_last_chunk_eoc_in_top_k,
                                          eos_score, expected_topk_lens,
                                          expected_topk_scores):
  """Checks topk lens/scores under `force_last_chunk_eoc_in_top_k`.

  Args:
    is_last_chunk: Value used to fill the per-hyp `is_last_chunk` result of
      every step.
    force_last_chunk_eoc_in_top_k: Forwarded to the BeamSearchHelper params.
    eos_score: Amount added to the eos logit at every step.
    expected_topk_lens: Expected value of `topk_lens`.
    expected_topk_scores: Expected value of `topk_scores` (compared with
      atol=1e-6).
  """
  with self.session() as sess:
    vocab_size = 30
    tgt_len = 10
    num_hyps_per_beam = 3
    src_batch_size = 2
    tgt_batch_size = src_batch_size * num_hyps_per_beam
    p = beam_search_helper.BeamSearchHelper.Params().Set(
        name='bsh',
        target_eoc_id=0,
        target_seq_len=tgt_len,
        num_hyps_per_beam=num_hyps_per_beam,
        beam_size=100000.0,  # Beam search until the end.
        force_last_chunk_eoc_in_top_k=force_last_chunk_eoc_in_top_k,
    )
    bs_helper = p.Instantiate()

    def InitBeamSearchCallBack(unused_theta, unused_encoder_outputs,
                               unused_num_hyps_per_beam):
      # Zero-filled initial results with is_last_chunk all False, and empty
      # states.
      return py_utils.NestedMap(
          log_probs=tf.zeros([tgt_batch_size, vocab_size]),
          atten_probs=tf.zeros([tgt_batch_size, 0]),
          is_last_chunk=tf.zeros([tgt_batch_size],
                                 tf.bool)), py_utils.NestedMap()

    def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                  unused_step_ids, states,
                                  unused_num_hyps_per_beam):
      # Same probs for each id.
      logits = tf.zeros([tgt_batch_size, vocab_size])
      # Except eoc has slightly lower score.
      logits = logits - 1.0 * tf.expand_dims(
          tf.one_hot(p.target_eoc_id, vocab_size), 0)
      # eos has very low score (can not terminate by eos)
      logits = logits + eos_score * tf.expand_dims(
          tf.one_hot(p.target_eos_id, vocab_size), 0)
      return py_utils.NestedMap(
          atten_probs=tf.zeros([tgt_batch_size, 0]),
          log_probs=logits,
          is_last_chunk=tf.fill([tgt_batch_size],
                                value=is_last_chunk)), states

    def PostBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                   unused_new_step_ids, states):
      return states

    # Only seq_lengths is supplied as encoder output.
    encoder_outputs = py_utils.NestedMap(
        seq_lengths=tf.zeros([src_batch_size], dtype=tf.int32))
    theta = py_utils.NestedMap()

    beam_search_output = bs_helper.BeamSearchDecode(
        theta,
        encoder_outputs,
        init_beam_search_state=InitBeamSearchCallBack,
        pre_beam_search_step_callback=PreBeamSearchStepCallback,
        post_beam_search_step_callback=PostBeamSearchStepCallback)

    topk_lens, topk_scores = sess.run(
        [beam_search_output.topk_lens, beam_search_output.topk_scores])
    self.assertAllEqual(topk_lens, expected_topk_lens)
    self.assertAllClose(topk_scores, expected_topk_scores, atol=1e-6)
def __init__(self, params):
  """Initializes the layer from `params`.

  Args:
    params: The params used to construct this layer; `params.name` must be
      non-empty. A copy is stored on the instance.
  """
  cls_name = self.__class__.__name__
  assert params.name, ('Layer params for %s must have a "name"' % cls_name)

  # Derive the module name: every run of characters outside [a-zA-Z0-9_] in
  # the layer name becomes a single underscore.
  sanitized_name = re.sub('[^a-zA-Z0-9_]+', '_', params.name)
  tf_module_name = '_'.join(('bbf', cls_name, sanitized_name))
  py_utils.NestedMap.CheckKey(tf_module_name)

  # Base-class initialization.
  super().__init__(tf_module_name)

  # Auto-tracking is switched off: it cannot walk the py_utils.NestedMap
  # structures used throughout the codebase, and it appears to slow down
  # graph construction.
  # TODO(lingvo): Re-enable auto-tracking when fuller support is
  # added for key data structures used in Lingvo, and performance issue is
  # debugged more and understood better.
  self._setattr_tracking = False

  # The parent is the closest entry on the layer stack other than this layer,
  # or None when there is no such entry.
  self._parent = next(
      (layer for layer in reversed(_LAYER_STACK.stack) if layer is not self),
      None)

  self._params = params.Copy()
  tf.logging.debug('Creating layer %s with params: \n %s \n', cls_name,
                   str(params))

  # Variables owned by this layer.
  self._private_vars = py_utils.NestedMap()
  # Theta values derived from this layer's variables.
  self._private_theta = py_utils.NestedMap()
  # Children registered via CreateChild/CreateChildren.
  self._private_children = py_utils.NestedMap()
  # Every child layer created by this layer. For a well-formed layer this
  # matches self._private_children, i.e. all children were created through
  # CreateChild/CreateChildren.
  self._children_list = []
  # Theta entries with no directly corresponding variable, e.g. the
  # concatenation of sharded variables.
  self._extra_theta = py_utils.NestedMap()
  # Registered accumulators.
  self._private_accumulators = py_utils.NestedMap()
  # Functions private to this layer; populated with AddFunction.
  self._private_fns = {}
  # Variable name -> symbolic shape: a tuple of integers or symbolic
  # expressions, one per dimension of the variable.
  self._var_symbolic_shape_map = {}

  self._is_variable_free = False
  self._variables_to_create = {}
  self._create_variables_status = _CreateLayerVariablesStatus.NOT_CALLED
  # Remembers the tf.variable_scope(p.name) created by this layer so it can
  # be re-entered instead of creating a new one.
  self._self_variable_scope = None
def testCustomStepIds(self):
  """Runs BeamSearchDecode with an init state that includes 'step_ids'.

  The callbacks produce random logits from fixed seeds, and the decoded top-k
  ids, lengths and scores are compared against hard-coded expectations.
  """
  with self.session(use_gpu=False):
    np.random.seed(9384758)
    tf.random.set_seed(8274758)

    # Problem dimensions.
    vocab_size, src_len, tgt_len = 12, 5, 7
    num_hyps_per_beam, src_batch_size = 3, 2
    tgt_batch_size = src_batch_size * num_hyps_per_beam

    helper_params = beam_search_helper.BeamSearchHelper.Params().Set(
        name='bsh', target_seq_len=tgt_len)
    bs_helper = helper_params.Instantiate()

    def InitBeamSearchState(unused_theta, unused_encoder_outputs,
                            unused_num_hyps_per_beam):
      attention = tf.constant(
          np.random.normal(size=(tgt_batch_size, src_len)), dtype=tf.float32)
      initial_results = py_utils.NestedMap({
          'log_probs': tf.zeros([tgt_batch_size, vocab_size]),
          'atten_probs': attention,
          'step_ids': tf.zeros([tgt_batch_size, 1], dtype=tf.int32)
      })
      initial_states = py_utils.NestedMap({'atten_probs': attention})
      return initial_results, initial_states

    def PreBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                  unused_step_ids, states,
                                  unused_num_hyps_per_beam):
      carried_probs = tf.identity(states.atten_probs)
      step_logits = tf.random.normal([tgt_batch_size, vocab_size],
                                     seed=8273747)
      step_results = py_utils.NestedMap({
          'atten_probs': carried_probs,
          'log_probs': step_logits
      })
      return step_results, states

    def PostBeamSearchStepCallback(unused_theta, unused_encoder_outputs,
                                   unused_new_step_ids, states):
      # States pass through unchanged.
      return states

    src_enc = tf.random.normal([src_len, src_batch_size, 8], seed=982774838)
    src_enc_padding = tf.constant(
        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        dtype=tf.float32)
    encoder_outputs = py_utils.NestedMap(
        encoded=src_enc, padding=src_enc_padding)

    decoded = bs_helper.BeamSearchDecode(
        py_utils.NestedMap(), encoder_outputs, num_hyps_per_beam,
        InitBeamSearchState, PreBeamSearchStepCallback,
        PostBeamSearchStepCallback)

    topk_ids, topk_lens, topk_scores = self.evaluate(
        [decoded.topk_ids, decoded.topk_lens, decoded.topk_scores])

    # Dump the actual values so expectations can be refreshed from the log.
    for actual in (topk_ids, topk_lens, topk_scores):
      print(np.array_repr(actual))

    expected_topk_ids = [
        [4, 3, 4, 3, 2, 0, 0],
        [4, 3, 11, 2, 0, 0, 0],
        [4, 3, 6, 2, 0, 0, 0],
        [6, 0, 4, 6, 6, 11, 2],
        [6, 0, 4, 6, 1, 2, 0],
        [6, 0, 4, 6, 6, 2, 0],
    ]
    expected_topk_lens = [5, 4, 4, 7, 6, 6]
    expected_topk_scores = [
        [8.27340603, 6.26949024, 5.59490776],
        [9.74691486, 8.46679497, 7.14809656],
    ]
    self.assertAllEqual(expected_topk_ids, topk_ids.tolist())
    self.assertAllEqual(expected_topk_lens, topk_lens.tolist())
    self.assertAllClose(expected_topk_scores, topk_scores)
def zero_state(self, theta, batch_size):
  """Collects the zero state of every element of `self.rnn`.

  Args:
    theta: Weights for this layer; `theta.rnn[i]` is handed to `self.rnn[i]`.
    batch_size: Forwarded unchanged to each element's `zero_state`.

  Returns:
    A `py_utils.NestedMap` whose `rnn` field is a list holding, at index i,
    the result of `self.rnn[i].zero_state(theta.rnn[i], batch_size)`.
  """
  per_layer_states = [
      layer.zero_state(theta.rnn[idx], batch_size)
      for idx, layer in enumerate(self.rnn)
  ]
  return py_utils.NestedMap(rnn=per_layer_states)
def testGreedySearchHelper(self):
  """Runs GreedySearchDecode with seeded callbacks and checks its outputs.

  The decoded hypothesis ids, lengths and done flags are compared against
  hard-coded expectations.
  """
  with self.session(use_gpu=False):
    np.random.seed(9384758)
    tf.random.set_seed(8274758)

    # Problem dimensions; greedy search keeps one hypothesis per source.
    vocab_size, src_len, tgt_len = 12, 5, 7
    src_batch_size = 2
    tgt_batch_size = src_batch_size

    helper_params = beam_search_helper.GreedySearchHelper.Params().Set(
        name='gsh', target_seq_len=tgt_len)
    gs_helper = helper_params.Instantiate()

    def InitGreedySearchState(unused_theta, unused_encoder_outputs,
                              unused_num_hyps_per_beam):
      attention = tf.constant(
          np.random.normal(size=(tgt_batch_size, src_len)), dtype=tf.float32)
      initial_results = py_utils.NestedMap({
          'log_probs': tf.zeros([tgt_batch_size, vocab_size]),
          'atten_probs': attention,
      })
      initial_states = py_utils.NestedMap({'atten_probs': attention})
      return initial_results, initial_states

    def PreGreedySearchStepCallback(unused_theta, unused_encoder_outputs,
                                    unused_step_ids, states,
                                    unused_num_hyps_per_beam):
      carried_probs = tf.identity(states.atten_probs)
      step_logits = tf.random.normal([tgt_batch_size, vocab_size],
                                     seed=8273747)
      step_results = py_utils.NestedMap({
          'atten_probs': carried_probs,
          'log_probs': step_logits
      })
      return step_results, states

    def PostGreedySearchStepCallback(unused_theta, unused_encoder_outputs,
                                     unused_new_step_ids, states):
      # States pass through unchanged.
      return states

    src_enc = tf.random.normal([src_len, src_batch_size, 8], seed=982774838)
    src_enc_padding = tf.constant(
        [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
        dtype=tf.float32)
    encoder_outputs = py_utils.NestedMap(
        encoded=src_enc, padding=src_enc_padding)

    hyp_ids_t, hyp_lens_t, done_hyps_t = gs_helper.GreedySearchDecode(
        py_utils.NestedMap(), encoder_outputs, InitGreedySearchState,
        PreGreedySearchStepCallback, PostGreedySearchStepCallback)

    final_hyp_ids, final_hyp_lens, final_done_hyps = self.evaluate(
        [hyp_ids_t, hyp_lens_t, done_hyps_t])

    # Dump the actual values so expectations can be refreshed from the log.
    for actual in (final_hyp_ids, final_hyp_lens, final_done_hyps):
      print(np.array_repr(actual))

    expected_hyp_ids = [
        [2, 2, 6, 7, 1, 9, 4],
        [3, 9, 3, 9, 6, 5, 10],
    ]
    expected_hyp_lens = [1, 7]
    expected_done_hyps = [True, False]
    self.assertAllEqual(expected_hyp_ids, final_hyp_ids.tolist())
    self.assertAllEqual(expected_hyp_lens, final_hyp_lens.tolist())
    self.assertAllEqual(expected_done_hyps, final_done_hyps.tolist())
def FPropMeta(cls, p, inputs, *args):
  """Returns FLOPs and output-shape metadata for this layer.

  Args:
    cls: The class; not read here.
    p: Layer params; `inputs_from_decoder` and `num_classes` are read.
    inputs: Input shape. Its first two dims are used unless
      `p.inputs_from_decoder` is truthy.
    *args: Additional shapes. When `p.inputs_from_decoder` is truthy, the
      first two dims of `args[1]` are used instead of those of `inputs`.

  Returns:
    A `py_utils.NestedMap` with `flops` set to the constant 100 and
    `out_shapes` set to a 1-tuple holding the shape
    [dim1, dim2, p.num_classes].
  """
  if p.inputs_from_decoder:
    shape_source = args[1]
  else:
    shape_source = inputs
  dim1, dim2 = shape_source[:2]
  logits_shape = tshape.Shape([dim1, dim2, p.num_classes])
  return py_utils.NestedMap(flops=100, out_shapes=(logits_shape,))
def BuildInputBatch(self, batch_size, features_list, bucket_keys=None):
  """Assembles source/target feature tensors into one input batch.

  Args:
    batch_size: Unused by this implementation.
    features_list: A 6-element sequence unpacked as (src_ids, src_paddings,
      tgt_ids, tgt_paddings, tgt_labels, tgt_weights).
    bucket_keys: Stored unchanged in the result. If not None, bucket_keys[i]
      is the bucketing key of the i-th sample.

  Returns:
    py_utils.NestedMap with `bucket_keys`, `src` (ids, paddings) and `tgt`
    (ids, labels, weights, paddings). When `fprop_dtype` is set and differs
    from `dtype`, every floating-point tensor is cast to `fprop_dtype`.
  """
  p = self.params
  (src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels,
   tgt_weights) = features_list

  if p.pad_to_max_seq_length:
    assert p.source_max_length

    # A static shape is only supplied when every bucket shares one batch
    # limit.
    batch_limits = self.infeed_bucket_batch_limit
    smallest_limit = min(batch_limits)
    if smallest_limit == max(batch_limits):
      source_shape = [smallest_limit, p.source_max_length]
      target_shape = [smallest_limit, p.target_max_length]
    else:
      source_shape = target_shape = None

    # Ids, labels and weights are padded with 0; paddings are padded with 1.
    src_ids = py_utils.PadSequenceDimension(
        src_ids, p.source_max_length, 0, source_shape)
    src_paddings = py_utils.PadSequenceDimension(
        src_paddings, p.source_max_length, 1, source_shape)
    tgt_ids = py_utils.PadSequenceDimension(
        tgt_ids, p.target_max_length, 0, target_shape)
    tgt_paddings = py_utils.PadSequenceDimension(
        tgt_paddings, p.target_max_length, 1, target_shape)
    tgt_labels = py_utils.PadSequenceDimension(
        tgt_labels, p.target_max_length, 0, target_shape)
    tgt_weights = py_utils.PadSequenceDimension(
        tgt_weights, p.target_max_length, 0, target_shape)

  batch = py_utils.NestedMap()
  batch.bucket_keys = bucket_keys
  batch.src = py_utils.NestedMap(
      ids=tf.cast(src_ids, dtype=tf.int32), paddings=src_paddings)
  batch.tgt = py_utils.NestedMap(
      ids=tgt_ids,
      labels=tf.cast(tgt_labels, dtype=tf.int32),
      weights=tgt_weights,
      paddings=tgt_paddings)

  fprop_dtype = p.fprop_dtype
  if fprop_dtype is None or p.dtype == fprop_dtype:
    return batch

  def _CastFloats(tensor):
    # Non-floating tensors keep their dtype.
    return tf.cast(tensor, fprop_dtype) if tensor.dtype.is_floating else tensor

  return batch.Transform(_CastFloats)
def Sample(self, decoder_theta, source_encs, source_paddings, random_seed,
           init_state_callback, pre_step_callback, post_step_callback):
  """Draws one sampled target sequence for each source sequence.

  (Please see beam_search_helper.py for description of decoder callbacks.)

  Args:
    decoder_theta: A NestedMap object containing weights' values of the
      decoder layer and its children layers, to be passed to decoder
      callbacks.
    source_encs: source encoding, to be passed to decoder callbacks.
    source_paddings: source padding, to be passed to decoder callbacks.
    random_seed: a scalar int32 tensor representing the random seed.
    init_state_callback: decoder._InitBeamSearchStateCallback.
    pre_step_callback: decoder._PreBeamSearchStepCallback.
    post_step_callback: decoder._PostBeamSearchStepCallback.

  Returns:
    A NestedMap containing the following tensors:

    - 'logits': [batch, max_target_length, vocab_size], representing the
      distribution from which target sequences are sampled.
    - 'ids': [batch, max_target_length] of int32, representing the target
      sequence ids, not including target_sos_id, but maybe ending with
      target_eos_id if end-of-sequence is reached before target_seq_len.
    - 'paddings': [batch, max_target_length] of 0/1, where 1 represents
      a padded timestep.
  """
  p = self.params
  assert p.temperature > 0

  recurrent_theta = py_utils.NestedMap(
      theta=decoder_theta,
      random_seed=random_seed,
      source_encs=source_encs,
      source_paddings=source_paddings)
  init_result, init_state = init_state_callback(
      recurrent_theta.theta, source_encs, source_paddings,
      num_hyps_per_beam=1)
  batch_size = tf.shape(init_result.log_probs)[0]

  # Every sequence starts from target_sos_id.
  recurrent_state0 = py_utils.NestedMap(
      timestep=tf.zeros(shape=[], dtype=tf.int32),
      logits=init_result.log_probs,
      ids=tf.fill([batch_size], tf.to_int32(p.target_sos_id)),
      bs_state=init_state)
  # The recurrence only needs the time dimension; input values are ignored.
  dummy_inputs = py_utils.NestedMap(
      dummy=tf.zeros([p.target_seq_len, batch_size]))

  def Step(recurrent_theta, state0, inputs):
    """Computes one decoder step."""
    del inputs
    with tf.name_scope('single_sampler_step'):
      # Compute logits and states from the previously sampled ids.
      prev_ids = tf.expand_dims(state0.ids, 1)  # [batch, 1].
      step_result, step_state = pre_step_callback(
          recurrent_theta.theta,
          recurrent_theta.source_encs,
          recurrent_theta.source_paddings,
          prev_ids,
          state0.bs_state,
          num_hyps_per_beam=1)
      step_batch = tf.shape(step_result.log_probs)[0]

      state1 = py_utils.NestedMap(timestep=state0.timestep + 1)
      state1.logits = step_result.log_probs

      # Sample ids from the temperature-scaled logits. The seed pairs the
      # caller's random seed with the current timestep.
      scaled_logits = state1.logits / p.temperature
      step_seed = tf.stack([recurrent_theta.random_seed, state0.timestep])
      drawn = tf.contrib.stateless.stateless_multinomial(
          scaled_logits,
          num_samples=1,
          seed=step_seed,
          output_dtype=state0.ids.dtype,
          name='sample_next_id')
      state1.ids = tf.reshape(drawn, [step_batch])  # [batch].

      if 'is_last_chunk' in step_result and p.target_eoc_id >= 0:
        # On the last chunk, a sampled target_eoc_id becomes target_eos_id.
        is_final_eoc = tf.logical_and(
            step_result.is_last_chunk,
            tf.equal(state1.ids, p.target_eoc_id))
        state1.ids = tf.where(
            is_final_eoc,
            tf.fill(tf.shape(state1.ids), p.target_eos_id),
            state1.ids)

      state1.bs_state = post_step_callback(
          recurrent_theta.theta,
          recurrent_theta.source_encs,
          recurrent_theta.source_paddings,
          state1.ids,
          step_state)
    return state1, py_utils.NestedMap()

  accumulated_states, _ = recurrent.Recurrent(
      recurrent_theta, recurrent_state0, dummy_inputs, Step)

  # Accumulated states are time-major; transpose them to batch-major.
  sampled = py_utils.NestedMap(
      logits=tf.transpose(accumulated_states.logits, [1, 0, 2]),
      ids=tf.transpose(accumulated_states.ids))
  sampled.paddings = tf.cast(
      _ComputePaddings(sampled.ids, p.target_eos_id), sampled.logits.dtype)

  # Force ids to be eos_id if the timestep is padded.
  sampled.ids = tf.where(
      tf.equal(sampled.paddings, 0), sampled.ids,
      tf.fill(tf.shape(sampled.ids), p.target_eos_id))

  # Restore the static shape using the batch dim of the initial log_probs.
  static_shape = [init_result.log_probs.shape[0], p.target_seq_len]
  sampled.ids.set_shape(static_shape)
  sampled.paddings.set_shape(static_shape)
  return sampled
def CellFn(theta, state0, unused_inputs):
  """Runs `self._FProp` on (theta, state0); the inputs are ignored.

  Returns:
    A pair of the `_FProp` result and an empty `py_utils.NestedMap`.
  """
  return self._FProp(theta, state0), py_utils.NestedMap()
def FPropMeta(cls, p, inputs, padding=None):
  """Returns FLOPs and output-shape metadata for this layer.

  Args:
    cls: The class; not read here.
    p: Layer params; not read here.
    inputs: Input shape; validated with `py_utils.CheckShapes`.
    padding: Not read here.

  Returns:
    A `py_utils.NestedMap` with `flops` equal to the number of input elements
    times `_BN_FLOPS_PER_ELEMENT`, and `out_shapes` equal to `(inputs,)`.
  """
  py_utils.CheckShapes((inputs,))
  total_flops = inputs.num_elements() * _BN_FLOPS_PER_ELEMENT
  return py_utils.NestedMap(flops=total_flops, out_shapes=(inputs,))
def FPropMeta(cls, p, inputs):
  """Returns FLOPs and output-shape metadata for this layer.

  Args:
    cls: The class; not read here.
    p: Layer params; not read here.
    inputs: Input shape; validated with `py_utils.CheckShapes`.

  Returns:
    A `py_utils.NestedMap` with `flops` set to the constant 1 and
    `out_shapes` equal to `(inputs,)`.
  """
  shapes = (inputs,)
  py_utils.CheckShapes(shapes)
  return py_utils.NestedMap(flops=1, out_shapes=shapes)