def _BatchScatter(default_tensor, indices, values):
  """Applies tf.tensor_scatter_nd_update independently to every batch row.

  Args:
    default_tensor: A float tensor of shape [batch, vocab] holding the values
      to start from.
    indices: An int tensor of shape [batch, k]; for each row, the k positions
      of `default_tensor` to overwrite.
    values: A float tensor of shape [batch, k]; values[i][j] is written at
      position indices[i][j] of row i.

  Returns:
    A tensor like `default_tensor` in which element (i, indices[i][j]) has been
    replaced by values[i][j].
  """
  num_rows = tf.shape(default_tensor)[0]
  # Row number of every entry in `indices`, laid out with the same [batch, k]
  # shape so it can be paired element-wise with the column indices.
  row_ids = tf.broadcast_to(
      tf.expand_dims(tf.range(num_rows, dtype=indices.dtype), 1),
      tf.shape(indices))
  # [batch, k, 2]: full (row, column) coordinates for the scatter.
  scatter_coords = tf.stack([row_ids, indices], axis=2)
  return tf.tensor_scatter_nd_update(default_tensor, scatter_coords, values)
def FProp(self, theta, inputs, paddings):
  """Quantization-aware projection followed by tanh and padding masking.

  Args:
    theta: Weights of this layer; `theta.w` is the projection matrix.
    inputs: Input tensor whose last dimension is `p.input_dim`.
    paddings: Padding tensor; positions with a value > 0 are zeroed in the
      output. It is broadcast against the output by multiplication.

  Returns:
    The activations, shaped like `inputs` but with last dimension
    `p.output_dim`, zeroed at padded positions.
  """
  p = self.params
  fns = self.fns

  # Weights and top-level activations are the most important tensors to tag
  # for quantization:
  #   - weights go through the qweight decorator;
  #   - inputs/activations go through self.QTensor(), where the name should
  #     generally match a self.TrackQTensor call made in the constructor so
  #     that the tensor is accounted for individually.
  w = fns.qweight(theta.w)
  inputs = self.QTensor('inputs', inputs)

  flat_inputs = tf.reshape(inputs, [-1, p.input_dim])
  flat_inputs, w = self.ToAqtInputs(
      'aqt_w',
      act=flat_inputs,
      weight=w,
      w_feature_axis=-1,
      w_expected_scale_shape=(1, p.output_dim))

  # qmatmul comes from the function library and automatically tracks its
  # output against the qtensor 'transformed'.
  out = fns.qmatmul(flat_inputs, w, qt='transformed')
  out = self.FromAqtMatmul('aqt_w', out)

  # Restore the leading dimensions of `inputs`.
  out_shape = tf.concat([tf.shape(inputs)[:-1], [p.output_dim]], 0)
  out = tf.reshape(out, out_shape)

  # Simple activation functions are wrapped with their range decorator so the
  # result never exceeds the precision of the underlying representation.
  out = fns.qtanh(out)

  # Padding is applied with a boolean mask rather than `out *= 1.0 - paddings`:
  # paddings may live in a completely different numeric range than `out`, so
  # direct arithmetic between them is avoided. QRPadding adds range protection,
  # which mostly matters when padding is dynamic at inference time.
  paddings = self.QRPadding(paddings)
  paddings *= tf.ones_like(out)  # Broadcast to the shape of `out`.
  is_padded = paddings > 0.0
  out = tf.where(is_padded, tf.zeros_like(out), out)
  return out
def FPropTower(self, theta, input_batch):
  """Runs the LM forward pass for one tower and assembles its metrics.

  Args:
    theta: A NestedMap of weights; `theta.lm` is forwarded to `self.lm.FProp`.
    input_batch: A NestedMap providing `ids`, `paddings`, `labels`, `weights`
      and `word_count`. When it also contains `segment_ids`, `segment_pos` and
      `num_sentences` are read as well.

  Returns:
    A pair (metrics, per_example). `metrics` maps each metric name to a
    (value, weight) tuple; `per_example` is an empty dict.
  """
  p = self.params
  tf.logging.info('input_batch=%r', input_batch)

  ids, paddings, labels_ids, weights = self._TrimIfPossible(
      input_batch.ids, input_batch.paddings, input_batch.labels,
      input_batch.weights)

  fprop_dtype = py_utils.FPropDtype(p)
  paddings = tf.cast(paddings, fprop_dtype)
  weights = tf.cast(weights, fprop_dtype)
  tf.logging.info('inputs={}'.format((ids, paddings, labels_ids, weights)))

  batch_size = tf.shape(ids)[0]
  has_segments = 'segment_ids' in input_batch

  # Segment information is forwarded only when the batch carries it.
  extra_lm_args = {}
  if has_segments:
    extra_lm_args.update(
        segment_ids=input_batch.segment_ids,
        segment_pos=input_batch.segment_pos)

  labels = py_utils.NestedMap(class_ids=labels_ids, class_weights=weights)
  # No initial state is passed to the LM (state0 is None).
  xent_output, _ = self.lm.FProp(theta.lm, ids, paddings, None, labels,
                                 **extra_lm_args)

  if has_segments:
    num_sentences = input_batch.num_sentences
  else:
    # Without segment info every example counts as exactly one sentence.
    num_sentences = tf.ones(shape=[batch_size], dtype=tf.int32)

  # +num_sentences accounts for the end-of-sequence symbol of each sentence.
  num_words = tf.cast(
      tf.reduce_sum(input_batch.word_count + num_sentences), fprop_dtype)

  predicted_labels = tf.cast(xent_output.per_example_argmax, labels_ids.dtype)
  num_preds = xent_output.total_weight

  is_correct = tf.cast(tf.equal(labels_ids, predicted_labels), fprop_dtype)
  weighted_correct = tf.reduce_sum(is_correct * weights)
  # The denominator is clamped to at least 1 to avoid dividing by zero.
  mean_acc = weighted_correct / tf.math.maximum(num_preds, 1)

  if p.train.sum_loss_across_tokens_in_batch:
    loss = xent_output.total_xent
  else:
    loss = xent_output.avg_xent

  metrics = {
      'loss': (loss, num_preds),
      'fraction_of_correct_next_step_preds': (mean_acc, num_preds),
      'log_pplx': (xent_output.avg_xent, num_preds),
      'log_pplx_per_word': (xent_output.total_xent / num_words, num_words),
      'num_predictions': (num_preds, 1),
      'num_words': (num_words, 1),
      'num_sentences': (tf.reduce_sum(num_sentences), 1),
  }
  return metrics, {}
def Proc(record):
  """Parses one serialized tf.Example into model inputs and a bucket key.

  Reads `self.params` and `self.StringsToIds` from the enclosing scope.

  Args:
    record: A serialized tf.Example with variable-length features 'uttid'
      (string), 'transcript' (string) and 'frames' (float32).

  Returns:
    A tuple (uttid, tgt_ids, tgt_labels, tgt_paddings, frames, src_paddings,
    bucket_key).
  """
  feature_spec = {
      'uttid': tf.VarLenFeature(tf.string),
      'transcript': tf.VarLenFeature(tf.string),
      'frames': tf.VarLenFeature(tf.float32),
  }
  example = tf.parse_single_example(record, feature_spec)
  # Keep only the dense values of each sparse feature.
  fval = {name: sparse.values for name, sparse in six.iteritems(example)}

  # The frames arrive flattened; restore the time-major [time, frame_size]
  # layout.
  frames = tf.reshape(fval['frames'], shape=[-1, self.params.frame_size])
  fval['frames'] = frames

  # The input duration (number of frames) determines the bucket.
  bucket_key = tf.to_int32(tf.shape(frames)[0])
  if self.params.append_eos_frame:
    bucket_key += 1

  tgt_ids, tgt_labels, tgt_paddings = self.StringsToIds(fval['transcript'])
  # Every source frame is real data, so the source paddings are all zero.
  src_paddings = tf.zeros([tf.shape(frames)[0]], dtype=tf.float32)

  return (fval['uttid'], tgt_ids, tgt_labels, tgt_paddings, frames,
          src_paddings, bucket_key)
def _MaybeExpandPaddings(self, inputs, paddings):
  """Appends a trailing size-1 dimension to `paddings` when `inputs` has one more dim.

  Args:
    inputs: A tensor whose rank equals the rank of `paddings` or exceeds it by
      exactly one; this is enforced with assertions.
    paddings: The padding tensor to expand.

  Returns:
    `paddings` reshaped to the rank of `inputs`: unchanged when the ranks
    already match, otherwise with one extra dimension of size 1 at the end.
  """
  rank_gap = tf.rank(inputs) - tf.rank(paddings)
  # The rank difference must be 0 or 1.
  rank_checks = [
      py_utils.assert_less_equal(rank_gap, 1),
      py_utils.assert_greater_equal(rank_gap, 0),
  ]
  paddings = py_utils.with_dependencies(rank_checks, paddings)

  # Target shape is the current shape followed by `rank_gap` ones.
  expanded_shape = tf.concat(
      [tf.shape(paddings), tf.tile([1], [rank_gap])], axis=0)
  return tf.reshape(paddings, expanded_shape)
def GetEncoderEmbeddingsDefaultTheta(self, input_ids, task_ids=None):
  """Builds encoder input embeddings using the layer's own theta.

  Sums token and position embeddings (plus task embeddings when enabled) and
  applies the source dropout layer.

  Args:
    input_ids: Token ids. The time axis is 0 when `p.batch_dim` is non-zero and
      1 otherwise.
    task_ids: Optional task ids; used only when `p.enc_task_emb` is set.

  Returns:
    The embeddings after dropout.
  """
  p = self.params
  time_axis = 0 if p.batch_dim else 1
  seq_len = tf.shape(input_ids)[time_axis]

  token_embs = self.src_token_emb.EmbLookup(self.theta.src_token_emb,
                                            input_ids)
  # Position embeddings get a size-1 batch axis so they broadcast over it.
  position_embs = self.src_pos_emb.FProp(self.theta.src_pos_emb, seq_len)
  position_embs = tf.expand_dims(position_embs, p.batch_dim)
  input_embs = token_embs + position_embs

  use_task_emb = task_ids is not None and p.enc_task_emb
  if use_task_emb:
    input_embs = input_embs + self.src_task_emb.EmbLookup(
        self.theta.src_task_emb, task_ids)

  return self.src_dropout.FProp(self.theta.src_dropout, input_embs)
def FProp(self, theta, inputs):
  """Applies batch normalization.

  Follows the implementation in github.com/
  tensorflow/tpu/blob/master/models/official/amoeba_net/network_utils.py#L550

  Args:
    theta: A nested map object containing weights' values of this layer and
      its children layers.
    inputs: The inputs tensor. Shaped [..., dim].

  Returns:
    Output after applying batch normalization, with the same shape and dtype
    as 'inputs'.
  """
  p = self.params
  original_dtype = inputs.dtype
  inputs = tf.cast(inputs, p.dtype)
  # The last input dimension must match the shape of beta.
  dim_check = py_utils.assert_shape_match([tf.shape(inputs)[-1]],
                                          tf.shape(theta.beta))
  inputs = py_utils.with_dependencies([dim_check], inputs)

  with tf.name_scope(p.name) as scope:
    if self.do_eval:
      # Evaluation normalizes with the accumulated moving statistics.
      mean = theta.moving_mean
      variance = theta.moving_variance
    else:
      # Training normalizes with the statistics of the current batch.
      mean, variance = self._Moments(inputs, p.bn_group_size)
      mean = py_utils.CheckNumerics(
          mean, 'mean of {} failed numeric check'.format(scope))
      variance = py_utils.CheckNumerics(
          variance, 'variance of {} failed numeric check'.format(scope))

    outputs = tf.nn.batch_normalization(inputs, mean, variance, theta.beta,
                                        theta.gamma, p.epsilon)
    outputs.set_shape(inputs.get_shape())
    return tf.cast(outputs, original_dtype)
def _FrequencyMask(self,
                   inputs,
                   num_freq=80,
                   dtype=tf.float32,
                   domain_id_index=0):
  """Applies random frequency masking to a batch of features.

  Args:
    inputs: Batch of input features of shape (batch_size, time_length,
      num_freq, channels).
    num_freq: Number of frequencies.
    dtype: Data type of the generated mask.
    domain_id_index: Index into `p.freq_mask_max_bins` selecting the maximum
      mask width to use.

  Returns:
    `inputs` with random frequency masking applied, or `inputs` itself when
    the selected maximum mask width is zero.
  """
  p = self.params
  max_bins = p.freq_mask_max_bins[domain_id_index]
  if max_bins == 0:
    # Masking is disabled for this domain.
    return inputs

  # Draw one random mask length per example.
  mask_lengths = tf.random.uniform((tf.shape(inputs)[0],),
                                   maxval=max_bins,
                                   dtype=tf.int32,
                                   seed=p.random_seed)

  # Build per-example masks along the frequency axis.
  freq_masks = self._GetMask(
      tf.shape(inputs)[0],
      mask_lengths,
      choose_range=num_freq,
      mask_size=num_freq,
      dtype=dtype)
  # Multiply each frequency bin by its mask value, broadcasting over time and
  # channels.
  return tf.einsum('bxyc,by->bxyc', inputs, freq_masks)
def ApplyBias():
  """Biases the step's log-probs toward the target labels.

  Reads `labels`, `weights`, `time_step`, `step_ids`, `states`, `bs_results`,
  `num_hyps_per_beam` and `p` from the enclosing scope.

  Returns:
    A pair (log_probs, consistent): the interpolated log-probabilities and the
    updated boolean consistency flags.
  """

  def TileForBeamAndFlatten(tensor):
    row = tf.reshape(tensor, [1, -1])  # [1, src_batch]
    tiled = tf.tile(row, [num_hyps_per_beam, 1])  # [num_hyps_per_beam, src_batch]
    tgt_batch = tf.shape(step_ids)[0]  # num_hyps_per_beam*src_batch
    return tf.reshape(tiled, [tgt_batch])

  # A hypothesis stays consistent if step_ids equals the label of the previous
  # step.
  # TODO(navari): Consider updating consistent only if weights > 0. Then
  # re-evaluate the need for bias_only_if_consistent=True.
  # prev_label is incorrect at step 0; the is_step0 term below overrides it.
  prev_step = tf.maximum(time_step - 1, 0)
  prev_label = TileForBeamAndFlatten(tf.gather(labels, prev_step, axis=1))
  is_step0 = tf.equal(time_step, 0)
  matches_prev = tf.equal(prev_label, tf.squeeze(step_ids, 1))
  local_consistence = tf.math.logical_or(is_step0, matches_prev)
  consistent = tf.math.logical_and(states.consistent, local_consistence)

  # Label and weight slices for the current time step.
  label = TileForBeamAndFlatten(tf.gather(labels, time_step, axis=1))
  weight = TileForBeamAndFlatten(tf.gather(weights, time_step, axis=1))
  if p.bias_only_if_consistent:
    # Drop the bias for hypotheses that have already diverged.
    weight = weight * tf.cast(consistent, py_utils.FPropDtype(p))

  # Turn the dense labels into one-hot probabilities: [tgt_batch, vocab_size].
  vocab_size = tf.shape(bs_results.log_probs)[1]
  label_probs = tf.one_hot(label, vocab_size, dtype=py_utils.FPropDtype(p))
  pred_probs = tf.exp(bs_results.log_probs)

  # Interpolate predicted and label probabilities; weight must lie in [0, 1].
  weight = tf.expand_dims(weight, 1)
  range_checks = [
      py_utils.assert_less_equal(weight, 1.),
      py_utils.assert_greater_equal(weight, 0.),
  ]
  blended = (1.0 - weight) * pred_probs + weight * label_probs
  probs = py_utils.with_dependencies(range_checks, blended)

  # Clamp so tf.math.log only ever sees positive values.
  probs = tf.maximum(probs, tf.constant(1e-12, dtype=probs.dtype))
  return tf.math.log(probs), consistent
def FProp(self, theta, inputs):
  """Apply projection to inputs.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    inputs: The inputs tensor. Shaped [..., input_dims].

  Returns:
    Projected inputs.
  """
  p = self.params
  with tf.name_scope(p.name):
    # Record the cost of the projection: 2 * (#rows) * input_dims *
    # output_dims.
    computation_cost.Add(
        self, 'flops',
        tf.reduce_prod(tf.to_int64(tf.shape(inputs)[:-1])) * tf.to_int64(
            symbolic.EvalExpr(symbolic.TENSOR_VALUES,
                              p.input_dims * p.output_dims)) * 2)
    use_tpu = py_utils.use_tpu()
    # The static rank is None when it is unknown. Comparing None with an int
    # raises TypeError, so the einsum path is only taken when the rank is
    # actually known; otherwise fall through to the generic reshape path.
    static_rank = inputs.shape.rank if inputs.shape is not None else None
    if use_tpu and static_rank is not None and static_rank < 26:
      # Avoids reshape if feasible and uses Einsum.
      if static_rank == 2:
        return tf.matmul(inputs, theta.w)
      s = ''.join([chr(x) for x in range(97, 123)])  # abc...xyz
      # One letter per leading dimension; 'y' and 'z' are reserved for the
      # contracted input dimension and the output dimension.
      return tf.einsum('{0}y,yz->{0}z'.format(s[:static_rank - 1]), inputs,
                       theta.w)

    # Generic path: flatten leading dims, multiply, then restore them.
    input_dim = py_utils.GetShape(inputs)[-1]
    act = tf.matmul(tf.reshape(inputs, [-1, input_dim]), theta.w)
    output_dim = tf.shape(theta.w)[-1]
    act = tf.reshape(act,
                     tf.concat([tf.shape(inputs)[:-1], [output_dim]], axis=0))
    return act
def FProp(self,
          theta,
          source_vecs,
          source_paddings,
          target_vecs,
          target_paddings,
          source_segment_id,
          target_segment_id,
          transparent_acc,
          transparent_acc_helper,
          source_task_id=None,
          target_task_id=None):
  """Computes logits from either the target or the source vectors.

  Only `theta` and one of `source_vecs` / `target_vecs` are used; the
  remaining arguments are accepted but ignored.

  Args:
    theta: Weights of this layer, forwarded to the parent class's `Logits`.
    source_vecs: Used as input when `p.inputs_from_decoder` is false.
    source_paddings: Unused.
    target_vecs: Used as input when `p.inputs_from_decoder` is true.
    target_paddings: Unused.
    source_segment_id: Unused.
    target_segment_id: Unused.
    transparent_acc: Unused.
    transparent_acc_helper: Unused.
    source_task_id: Unused.
    target_task_id: Unused.

  Returns:
    Logits reshaped to [dim0, dim1, p.num_classes], where dim0 and dim1 are the
    first two dimensions of the selected input.
  """
  del source_task_id
  del target_task_id
  p = self.params
  transformer_output = target_vecs if p.inputs_from_decoder else source_vecs

  lead_dim0 = tf.shape(transformer_output)[0]
  lead_dim1 = tf.shape(transformer_output)[1]
  # Flatten to 2-D for the softmax layer, then restore the leading dims.
  flat_input = tf.reshape(transformer_output, [-1, p.input_dim])
  logits = super().Logits(theta, [flat_input])
  return tf.reshape(logits, [lead_dim0, lead_dim1, p.num_classes])
def SplitTensors(xs, num_splits):
  """Splits tensors in `xs` into num_splits pieces along the 1st dimension.

  Args:
    xs: A tuple of tensors. Each tensor's 1st dimension is the same size.
    num_splits: A python integer.

  Returns:
    A list with len(xs) entries. Entry j is the list of pieces of xs[j]; its
    i-th element is the i-th split of xs[j] along the first dimension.
  """
  # Collect the first dimension of every tensor so they can be compared.
  leading_dims = tf.stack([tf.shape(x)[0] for x in xs])
  checks = [
      py_utils.assert_equal(
          leading_dims,
          tf.shape(xs[0])[0],
          message='first dim of tensors in xs must match'),
      # The check itself is >=: there must be at least num_splits rows.
      py_utils.assert_greater_equal(
          tf.shape(xs[0])[0],
          num_splits,
          message='first dim of tensors in xs must be greater than num_splits'
      ),
  ]
  leading_dims = py_utils.with_dependencies(checks, leading_dims)

  splits = ComputeSplits(tf.shape(xs[0])[0], num_splits)
  # Tie the assertions above into the graph that computes the split sizes.
  splits = py_utils.with_dependencies([leading_dims], splits)

  return [tf.split(axis=0, num_or_size_splits=splits, value=x) for x in xs]
def _BeamSearchDecode(self, input_batch):
  """Runs encoder + beam search and converts the results to strings.

  Args:
    input_batch: A NestedMap with `src` (ids, paddings) and `tgt` (ids, labels,
      weights, paddings).

  Returns:
    A dict with the target tensors, the source and reference strings, and the
    top-k decoded strings, lengths and scores.
  """
  p = self.params
  with tf.name_scope('fprop'), tf.name_scope(p.name):
    encoder_outputs = self.enc.FPropDefaultTheta(input_batch.src)
    decoder_outs = self.dec.BeamSearchDecode(encoder_outputs)

    hyps_shape_source = decoder_outs.topk_hyps
    topk_lens = decoder_outs.topk_lens

    # Lengths are reduced by one before conversion to strings.
    src_lens = tf.to_int32(tf.reduce_sum(1 - input_batch.src.paddings, 1) - 1)
    srcs = self.input_generator.IdsToStrings(
        input_batch.src.ids, src_lens, self._GetTokenizerKeyToUse('src'))

    topk_decoded = self.input_generator.IdsToStrings(
        decoder_outs.topk_ids, topk_lens - 1,
        self._GetTokenizerKeyToUse('tgt'))
    # Decoded strings and scores take the shape of topk_hyps.
    topk_decoded = tf.reshape(topk_decoded, tf.shape(hyps_shape_source))
    topk_scores = tf.reshape(decoder_outs.topk_scores,
                             tf.shape(hyps_shape_source))

    ref_lens = tf.to_int32(
        tf.reduce_sum(1.0 - input_batch.tgt.paddings, 1) - 1.0)
    refs = self.input_generator.IdsToStrings(
        input_batch.tgt.labels, ref_lens, self._GetTokenizerKeyToUse('tgt'))

    return {
        'target_ids': input_batch.tgt.ids,
        'target_labels': input_batch.tgt.labels,
        'target_weights': input_batch.tgt.weights,
        'target_paddings': input_batch.tgt.paddings,
        'sources': srcs,
        'targets': refs,
        'topk_decoded': topk_decoded,
        'topk_lens': topk_lens,
        'topk_scores': topk_scores,
    }
def BeamSearchDecodeOutputToDecoderTopK(decoder_outs,
                                        *,
                                        ids_to_strings_fn,
                                        tag=''):
  """Converts BeamSearchDecodeOutput to DecoderTopK.

  As a side-effect, also creates TF nodes used by eval pipelines
  ("top_k_decoded" and "top_k_scores").

  Args:
    decoder_outs: a beam_search_helper.BeamSearchDecodeOutput instance.
    ids_to_strings_fn: a function of (ids, lens) -> strings, where ids has
      shape [batch, length], lens has shape [batch], and strings has shape
      [batch].
    tag: optional tag for tf.identity() names.

  Returns:
    A DecoderTopK instance.
  """
  hyps = decoder_outs.topk_hyps
  ids = decoder_outs.topk_ids
  scores = decoder_outs.topk_scores
  decoded = decoder_outs.topk_decoded
  lens = tf.identity(decoder_outs.topk_lens, name='TopKLabelLengths' + tag)

  has_ids = decoder_outs.topk_ids is not None
  if has_ids:
    ids = tf.identity(ids, name='TopKLabelIds' + tag)
    # With the assumption that ids[-1] is always EOS token.
    # TODO(b/195027707): remove EOS token in better way.
    strings = ids_to_strings_fn(ids, lens - 1)
    strings = tf.identity(strings, name='top_k_decoded%s' % tag)
    # NOTE(review): `scores` is used here without the None check applied
    # below - confirm that ids are never present without scores.
    decoded = tf.reshape(strings, tf.shape(scores))

  has_scores_and_hyps = scores is not None and hyps is not None
  if has_scores_and_hyps:
    # Name the flat (lens-shaped) scores node, then reshape it like hyps.
    flat_scores = tf.reshape(scores, tf.shape(lens))
    scores = tf.identity(flat_scores, name='top_k_scores%s' % tag)
    scores = tf.reshape(scores, tf.shape(hyps))

  return DecoderTopK(hyps, ids, lens, scores, decoded)
def _Proc(record):
  """Parses a serialized tf.Example holding 'inputs' and 'targets' ids.

  Args:
    record: A serialized tf.Example with variable-length int64 features
      'inputs' and 'targets'.

  Returns:
    A pair (tensors, bucket_key), where tensors is the list [src_ids,
    src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights, src_pos,
    src_seg, tgt_pos, tgt_seg].
  """
  feature_spec = {
      'inputs': tf.VarLenFeature(tf.int64),
      'targets': tf.VarLenFeature(tf.int64),
  }
  parsed = tf.parse_single_example(record, feature_spec)
  # Keep only the dense values of each sparse feature.
  features = {name: sparse.values for name, sparse in six.iteritems(parsed)}

  src_ids = features['inputs']
  tgt_labels = features['targets']

  (src_paddings, tgt_ids, tgt_paddings, tgt_weights,
   bucket_key) = _DerivePaddingsAndIds(src_ids, tgt_labels)

  # Derive trivial segmentation for unpacked input: positions simply count up
  # and every token belongs to segment 0.
  src_len = tf.shape(src_ids)[0]
  tgt_len = tf.shape(tgt_ids)[0]
  src_pos = tf.range(src_len, dtype=tf.int32)
  src_seg = tf.zeros_like(src_paddings)
  tgt_pos = tf.range(tgt_len, dtype=tf.int32)
  tgt_seg = tf.zeros_like(tgt_paddings)

  tensors = [
      src_ids, src_paddings, tgt_ids, tgt_paddings, tgt_labels, tgt_weights,
      src_pos, src_seg, tgt_pos, tgt_seg
  ]
  return tensors, bucket_key
def InitBeamSearchStateCallback(theta, encoder_outputs, num_hyps_per_beam):
  """Wrapper for adding bias to _InitBeamSearchStateCallback.

  Expands the state to track the consistency of each hypothesis with the
  provided target. Reads `self` from the enclosing scope.

  Args:
    theta: A NestedMap object containing weights' values of this layer and its
      children layers.
    encoder_outputs: A NestedMap computed by encoder. Its `padding` is either
      a tensor or a dict of tensors; dimension 1 is taken as the batch size.
    num_hyps_per_beam: An int, number hyps to keep for source sentence.

  Returns:
    initial_results: a `.NestedMap` of initial results.
    states: a `.NestedMap` of initial model states that the client would like
      to keep track of for each hyp. The states relevant here are:
      time_step: A scalar indicating current step (=0 for initial state) of
        decoder. Must be provided and maintained by super.
      consistent: A boolean tensor of shape [batch_size * num_hyps_per_beam]
        which tracks whether each hypothesis has exactly matched
        encoder_outputs.targets so far.
  """
  initial_results, states = self._InitBeamSearchStateCallback(
      theta, encoder_outputs, num_hyps_per_beam)
  assert hasattr(states, 'time_step')

  padding = encoder_outputs.padding
  if not tf.is_tensor(padding):
    # Required for multisource models: padding is a dict, use its first entry.
    padding = list(encoder_outputs.padding.values())[0]
  batch_size = tf.shape(padding)[1]

  num_hyps = batch_size * num_hyps_per_beam
  # Every hypothesis starts out consistent.
  states.consistent = tf.ones([num_hyps], dtype=tf.bool)
  return initial_results, states
def SequenceLength(padding):
  """Computes the length of each sequence from its binary padding.

  Args:
    padding: A tensor of binary paddings shaped [batch, seqlen].

  Returns:
    seq_lens, an int32 tensor of shape [batch] containing the non-padded
    length of each sequence in the batch.
  """
  # Count the non-padded positions; rounding guards the float-to-int cast.
  non_padded = tf.reduce_sum(1 - padding, axis=1)
  lengths = tf.cast(tf.round(non_padded), tf.int32)
  # Get rid of any extra dimensions.
  batch_size = tf.shape(padding)[0]
  return tf.reshape(lengths, [batch_size], name='seq_lens')
def GetEncoderEmbeddingsDefaultTheta(self, input_ids, task_ids=None):
  """Builds time-major encoder input embeddings using the layer's own theta.

  Args:
    input_ids: Token ids; dimension 0 is used as the sequence length.
    task_ids: Optional task ids; used only when `p.enc_task_emb` is set.

  Returns:
    Token + position (+ task) embeddings after source dropout.
  """
  p = self.params
  seq_len = tf.shape(input_ids)[0]

  # [seq_len, batch, model_dim]
  token_embs = self.src_token_emb.EmbLookup(self.theta.src_token_emb,
                                            input_ids)
  # [seq_len, 1, model_dim]; the size-1 axis broadcasts over the batch.
  position_embs = self.src_pos_emb.FProp(self.theta.src_pos_emb, seq_len)
  position_embs = tf.expand_dims(position_embs, 1)
  input_embs = token_embs + position_embs

  use_task_emb = task_ids is not None and p.enc_task_emb
  if use_task_emb:
    input_embs = input_embs + self.src_task_emb.EmbLookup(
        self.theta.src_task_emb, task_ids)

  return self.src_dropout.FProp(self.theta.src_dropout, input_embs)
def _IsCounterClockwiseDirection(v1, v2, v3):
  """Checks if the path from v1 to v3 via v2 is counter-clockwise.

  When v1 is equal to v2, or v2 equals v3, return true, by fiat. This will
  work when the v's are padded vectors.

  Args:
    v1: a float Tensor of shape [..., 2], indicating the starting point.
    v2: a Tensor of same type and shape as v1, indicating the via point.
    v3: a Tensor of same type and shape as v1, indicating the ending point.

  Returns:
    True for all directions such that v1 to v3 via v2 is a counter clockwise
    direction.
  """
  # All three inputs must share one shape exactly; no broadcasting is allowed.
  v1 = py_utils.HasShape(v1, tf.shape(v2))
  v1 = py_utils.HasShape(v1, tf.shape(v3))

  x1, y1 = v1[..., 0], v1[..., 1]
  x2, y2 = v2[..., 0], v2[..., 1]
  x3, y3 = v3[..., 0], v3[..., 1]

  # lhs - rhs is the 2-D cross product of (v2 - v1) and (v3 - v1). The two
  # halves are compared directly instead of being subtracted.
  lhs = (y3 - y1) * (x2 - x1)
  rhs = (x3 - x1) * (y2 - y1)
  return lhs >= rhs
def testPreconditioning(self):
  """Checks that computed preconditioners match the inverse p-th root.

  Before the computation runs, the fetched preconditioners must be all zeros
  with no success status. After it runs, they must be close to the reference
  inverse p-th roots of the inputs.
  """
  graphdef = self.inverse_pth_root_graph()
  with tf.Session() as sess:
    global_step = tf.train.get_or_create_global_step()
    tf.global_variables_initializer().run()

    # Two random symmetric 4x4 inputs of the form A * A^T.
    rand_a = np.random.rand(4, 4)
    rand_b = np.random.rand(4, 4)
    exponents = [-0.25, -0.25]
    sym_a = np.dot(rand_a, rand_a.transpose())
    sym_b = np.dot(rand_b, rand_b.transpose())

    outputs, statuses = ops.get_preconditioners(
        [tf.shape(sym_a), tf.shape(sym_b)],
        keys=['a', 'b'],
        preconditioner_compute_graphdef=graphdef)
    # Nothing has been computed yet.
    self.assertFalse(any(sess.run(statuses)))

    preconditioner = ops.compute_preconditioners(
        [sym_a, sym_b],
        exponents,
        tf.cast(global_step, tf.int32),
        keys=['a', 'b'],
        sync=True,
        preconditioner_compute_graphdef=graphdef)
    self.assertAllClose(outputs[0].eval(), np.zeros((4, 4)), atol=1e-4)
    self.assertAllClose(outputs[1].eval(), np.zeros((4, 4)), atol=1e-4)

    preconditioner.run()
    self.assertTrue(any(sess.run(statuses)))

    expected_a = self.inverse_pth_root(sym_a, exponents[0])
    expected_b = self.inverse_pth_root(sym_b, exponents[1])
    outputs_np = sess.run(outputs)
    self.assertAllClose(outputs_np[0], expected_a[0].eval(), atol=1e-1)
    self.assertAllClose(outputs_np[1], expected_b[0].eval(), atol=1e-1)
def KnnIndices(points, query_points, k, valid_num=None, max_distance=None):
  """k-nearest neighbors of query_points in points.

  The caller should ensure that points[i, :valid_num[i], :] are the
  non-padding points.

  Padding is returned alongside indices. Non-padded points are guaranteed to
  be unique (non-repeated) points from original non-padded points.

  Padded points arise due to either a lack of points (k exceeds valid_num) or
  points are too far away (exceeds max distance).

  TODO(weihan,jngiam): For backwards compatibility with PointCNN, if there are
  fewer than k points to select (possibly because of valid_num), the points
  selected will first consist of those in the non-padded points, and then
  those from the padded points. This assumes that the padded points are
  duplications of the original points. PointCNN should be updated to respect
  padding.

  The auxiliary input 'valid_num' marks the number of non-padding points in
  each sample. It is needed because points are randomly duplicated to make the
  input fixed-size; searching the non-padding points first prevents the result
  from degenerating into k duplicates of the query point itself.

  Args:
    points: tensor of shape [N, P1, dims].
    query_points: tensor of shape [N, P2, dims]
    k: Integer.
    valid_num: tensor of shape [N,]
    max_distance: float representing the maximum distance that each neighbor
      can be. If there are no points within the distance, then the closest
      point is returned (regardless of distance). If this is set to None, then
      max_distance is not used.

  Returns:
    A pair of tensors:

    - indices: tensor of shape [N, P2, k].
    - padding: tensor of shape [N, P2 ,k] where 1 represents a padded point,
      and 0 represents an unpadded (real) point.
  """
  num_points = tf.shape(points)[1]
  if valid_num is None:
    padding = None
  else:
    # [N, P1] boolean mask: True for positions at or beyond valid_num.
    padding = tf.greater_equal(
        tf.range(num_points), tf.expand_dims(valid_num, -1))
  return NeighborhoodIndices(points, query_points, k, padding, max_distance)
def FPropTower(self, theta, input_batch):
  """Runs the LM on a trimmed, transposed batch and gathers its metrics.

  Args:
    theta: A NestedMap of weights; `theta.lm` is forwarded to the LM.
    input_batch: A NestedMap providing `ids` and `word_count`, plus whatever
      `_TrimIfPossibleThenTranspose` reads. `num_sentences` is read when
      `p.packed_input` is set.

  Returns:
    A pair (metrics, per_example). `metrics` maps each metric name to a
    (value, weight) tuple; `per_example` is an empty dict.
  """
  p = self.params
  batch_size = tf.shape(input_batch.ids)[0]
  transposed = self._TrimIfPossibleThenTranspose(input_batch)
  labels_ids = transposed.labels
  weights = transposed.weights

  state0 = self.lm.zero_state(theta.lm, batch_size)
  labels = py_utils.NestedMap(class_ids=labels_ids, class_weights=weights)

  extra_lm_args = {}
  if p.packed_input:
    # segment_id for FRNN should be of shape [time, batch, 1].
    extra_lm_args.update(
        segment_id=tf.expand_dims(transposed.segment_ids, -1))

  xent_output, _ = self.lm.FProp(theta.lm, transposed.ids,
                                 transposed.paddings, state0, labels,
                                 **extra_lm_args)

  # One end-of-sequence symbol is counted per sentence; unpacked input has
  # exactly one sentence per example.
  if p.packed_input:
    num_sentences = input_batch.num_sentences
  else:
    num_sentences = tf.constant(1, dtype=tf.int32)
  # Words and eos tokens.
  num_words = tf.cast(
      tf.reduce_sum(input_batch.word_count + num_sentences), tf.float32)

  predicted_labels = tf.cast(xent_output.per_example_argmax, labels_ids.dtype)
  num_preds = xent_output.total_weight

  is_correct = tf.cast(tf.equal(labels_ids, predicted_labels), tf.float32)
  weighted_correct = tf.reduce_sum(is_correct * weights)
  # The small epsilon keeps the division finite when there are no predictions.
  mean_acc = weighted_correct / (num_preds + 1e-4)

  if p.train.sum_loss_across_tokens_in_batch:
    loss = xent_output.total_xent
  else:
    loss = xent_output.avg_xent

  metrics = {
      'loss': (loss, num_preds),
      'fraction_of_correct_next_step_preds': (mean_acc, num_preds),
      'log_pplx': (xent_output.avg_xent, num_preds),
      'log_pplx_per_word': (xent_output.total_xent / num_words, num_words),
      'num_predictions': (num_preds, 1),
      'num_words': (num_words, 1)
  }
  return metrics, {}
def GetEmbeddings(self, emb_theta, emb, pos_emb_theta, pos_emb, dropout_theta,
                  dropout, input_ids, input_pos_ids):
  """Sums token and position embeddings and applies dropout.

  Args:
    emb_theta: Weights for `emb`.
    emb: Token embedding layer; its `EmbLookup` is called on `input_ids`.
    pos_emb_theta: Weights for `pos_emb`.
    pos_emb: Position embedding layer.
    dropout_theta: Weights for `dropout`.
    dropout: Dropout layer applied to the summed embeddings.
    input_ids: Token ids; dimension 0 is used as the sequence length.
    input_pos_ids: Explicit positions, used only when `p.packed_input` is set.

  Returns:
    The embeddings after dropout.
  """
  p = self.params
  seq_len = tf.shape(input_ids)[0]

  # [seq_len, batch, model_dim]
  token_embs = emb.EmbLookup(emb_theta, input_ids)

  if p.packed_input:
    # Packed inputs carry explicit positions.
    # [seq_len, batch, dim] or [batch, dim] in case of beam search.
    position_embs = pos_emb.FPropWithPosition(pos_emb_theta, input_pos_ids)
  else:
    # [seq_len, 1, model_dim]; the size-1 axis broadcasts over the batch.
    position_embs = tf.expand_dims(pos_emb.FProp(pos_emb_theta, seq_len), 1)

  summed_embs = token_embs + position_embs
  return dropout.FProp(dropout_theta, summed_embs)
def assign_preconditioner_to_host_vars(self):
  """Assign/Grab latest copy of preconditioners.

  For every preconditioner slot of every variable that does not fall back to
  diagonal preconditioning, fetches the latest value and assigns it to the
  slot. When a fetch is reported as unsuccessful, the slot keeps its current
  value.

  Returns:
    A grouped op running all assignments, or a no-op when there are no
    preconditioners to update.
  """
  keys_shapes_and_preconditioner_vars = []
  for var in self._all_vars_for_preconditioning:
    shape = var.get_shape()
    if self._fallback_to_diagonal_for_shape(shape):
      continue
    partitioned_v = TensorPartitioner.partition_tensor(var,
                                                       self._partition_info)
    num_partitions = len(partitioned_v)
    for pt_idx, pt in enumerate(partitioned_v):
      pt_shape = pt.get_shape()
      preconditioner_exists_for_dim = (
          self._preconditioner_available_for_dims(pt_shape))
      for i in range(len(pt_shape)):
        if not preconditioner_exists_for_dim[i]:
          continue
        key = self._key_for_var(var, i, pt_idx)
        preconditioner = self.get_slot(
            var,
            self._preconditioner_key_for_partition_and_dim(
                i, pt_idx, num_partitions))
        keys_shapes_and_preconditioner_vars.append(
            (key, tf.shape(preconditioner), preconditioner))

  if not keys_shapes_and_preconditioner_vars:
    return tf.no_op()

  keys, shapes, preconditioner_vars = zip(
      *keys_shapes_and_preconditioner_vars)

  preconditioner_vals, successes = x_ops.get_preconditioners(
      shapes,
      keys=keys,
      preconditioner_compute_graphdef=(self._preconditioner_compute_graphdef))

  assign_ops = []
  for preconditioner_var, preconditioner_val, success in zip(
      preconditioner_vars, preconditioner_vals, successes):
    # Cast with the dtype of the variable being assigned. The leftover
    # `preconditioner` from the collection loop above refers to the last slot
    # only and must not be used here.
    success_mult = tf.cast(success, preconditioner_var.dtype)
    # Blend: take the fetched value on success, keep the old value otherwise.
    assign_ops.append(
        state_ops.assign(
            preconditioner_var, (1.0 - success_mult) * preconditioner_var +
            success_mult * preconditioner_val))
  return tf.group(*assign_ops)
def FPropTower(self, theta, input_batch):
  """Runs the LM on a packed batch and assembles its metrics.

  Args:
    theta: A NestedMap of weights; `theta.lm` is forwarded to the LM.
    input_batch: A NestedMap providing `ids`, `labels`, `paddings`, `weights`,
      `segment_ids`, `segment_pos`, `word_count` and `num_sentences`.

  Returns:
    A pair (metrics, per_example). `metrics` maps each metric name to a
    (value, weight) tuple; `per_example` is an empty dict.
  """
  p = self.params
  fprop_dtype = py_utils.FPropDtype(p)
  tf.logging.info('input_batch=%r', input_batch)

  ids = input_batch.ids
  labels_ids = input_batch.labels
  paddings = tf.cast(input_batch.paddings, fprop_dtype)
  weights = tf.cast(input_batch.weights, fprop_dtype)
  tf.logging.info('inputs={}'.format((ids, paddings, labels_ids, weights)))

  batch_size = tf.shape(ids)[0]
  state0 = self.lm.zero_state(theta.lm, batch_size)
  labels = py_utils.NestedMap(class_ids=labels_ids, class_weights=weights)
  xent_output, _ = self.lm.FProp(
      theta.lm,
      ids,
      paddings,
      state0,
      labels,
      segment_ids=input_batch.segment_ids,
      segment_pos=input_batch.segment_pos)

  # +input_batch.num_sentences to account for the end of sequence symbol.
  sentences_as_int = tf.cast(input_batch.num_sentences, dtype=tf.int32)
  num_words = tf.cast(
      tf.reduce_sum(input_batch.word_count + sentences_as_int), fprop_dtype)

  predicted_labels = tf.cast(xent_output.per_example_argmax, labels_ids.dtype)
  num_sentences = tf.reduce_sum(input_batch.num_sentences)
  num_preds = xent_output.total_weight

  is_correct = tf.cast(tf.equal(labels_ids, predicted_labels), fprop_dtype)
  weighted_correct = tf.reduce_sum(is_correct * weights)
  # The denominator is clamped to at least 1 to avoid dividing by zero.
  mean_acc = weighted_correct / tf.math.maximum(num_preds, 1)

  metrics = {
      'loss': (xent_output.avg_xent, num_preds),
      'fraction_of_correct_next_step_preds': (mean_acc, num_preds),
      'log_pplx': (xent_output.avg_xent, num_preds),
      'log_pplx_per_word': (xent_output.total_xent / num_words, num_words),
      'num_predictions': (num_preds, 1),
      'num_words': (num_words, 1),
      'num_sentences': (num_sentences, 1)
  }
  return metrics, {}
def _Extract(self, features):
  """Extracts per-object label tensors, padded to `p.max_num_objects`.

  Args:
    features: A dict of parsed features with keys 'labels', 'label_ids',
      'bboxes_3d', 'bboxes_3d_num_points', 'label_metadata',
      'detection_difficulties' and 'tracking_difficulties'.

  Returns:
    A NestedMap with the padded labels, ids, difficulties, 3D boxes and their
    masks, plus 'speed' and 'acceleration' sliced from the label metadata.
  """
  p = self.params
  max_objs = p.max_num_objects

  # Label values match the proto enum car.open_dataset.Label.Type. The value
  # range is [1..4] for non-background labels.
  labels = tf.to_int32(_Dense(features['labels']))
  labels = py_utils.PadOrTrimTo(labels, [max_objs])

  # String ids are padded with the empty string.
  label_ids = tf.reshape(_Dense(features['label_ids'], ''), [-1])
  label_ids = py_utils.PadOrTrimTo(label_ids, [max_objs], '')

  bboxes_3d = tf.reshape(_Dense(features['bboxes_3d']), [-1, 7])
  # The mask is built before padding, so it is 1 exactly for the real boxes.
  bboxes_3d_mask = tf.ones([tf.shape(bboxes_3d)[0]])
  bboxes_3d_num_points = tf.to_int32(_Dense(features['bboxes_3d_num_points']))
  bboxes_3d = py_utils.PadOrTrimTo(bboxes_3d, [max_objs, 7])
  bboxes_3d_mask = py_utils.PadOrTrimTo(bboxes_3d_mask, [max_objs])
  bboxes_3d_num_points = py_utils.PadOrTrimTo(bboxes_3d_num_points,
                                              [max_objs])

  label_metadata = tf.reshape(_Dense(features['label_metadata']), [-1, 4])
  label_metadata = py_utils.PadOrTrimTo(label_metadata, [max_objs, 4])

  detection_difficulties = py_utils.PadOrTrimTo(
      tf.to_int32(_Dense(features['detection_difficulties'])), [max_objs])
  tracking_difficulties = py_utils.PadOrTrimTo(
      tf.to_int32(_Dense(features['tracking_difficulties'])), [max_objs])

  # Keep the mask as it was before any label filtering.
  unfiltered_bboxes_3d_mask = bboxes_3d_mask

  if p.filter_labels:
    # Zero the mask for every box whose label is not in the allowed set.
    valid_labels = tf.constant([p.filter_labels])
    label_is_valid = tf.reduce_any(
        tf.equal(tf.expand_dims(labels, 1), valid_labels), axis=1)
    bboxes_3d_mask = bboxes_3d_mask * tf.to_float(label_is_valid)

  return py_utils.NestedMap({
      'labels': labels,
      'label_ids': label_ids,
      'detection_difficulties': detection_difficulties,
      'tracking_difficulties': tracking_difficulties,
      'bboxes_3d': bboxes_3d,
      'bboxes_3d_mask': bboxes_3d_mask,
      'bboxes_3d_num_points': bboxes_3d_num_points,
      'unfiltered_bboxes_3d_mask': unfiltered_bboxes_3d_mask,
      'speed': label_metadata[:, :2],
      'acceleration': label_metadata[:, 2:],
  })
def _FPropLm(self, theta, state0, ids, paddings, misc=None):
  """LM FProp. Works for single step or entire seq.

  Args:
    theta: A NestedMap object containing weights for the layer and its
      children.
    state0: A NestedMap of states (specific to the layer).
    ids: Target ids, of shape [batch_size] for single step unrolling or
      [seq_len, batch_size] for the entire sequence.
    paddings: Target paddings, of the same shape as 'ids'.
    misc: NestedMap of miscellaneous items, which might be needed during
      training.

  Returns:
    (lm_output, state1):

    - lm_output: A NestedMap containing lm output. If 'ids' is 1-D, then
      lm_output should have shape [batch_size, dim]; if it is 2-D then the
      shape should be [seq_len, batch_size, dim].
    - state1: A NestedMap of updated states.
  """
  state1 = state0.DeepCopy()

  # `ids.shape` is either a tf.TensorShape or a plain sequence of dims.
  ids_shape = ids.shape
  if isinstance(ids_shape, tf.TensorShape):
    ids_rank = ids_shape.rank
  else:
    ids_rank = len(ids_shape)
  is_single_step = ids_rank == 1

  seq_len = 1 if is_single_step else tf.shape(ids)[0]

  self._ModifyLmBeforeFProp(theta, state0, ids, paddings, misc)

  with tf.name_scope('lm'):
    # The LM always sees [time, batch] inputs, including for a single step.
    ids = tf.reshape(ids, [seq_len, -1], name='reshape_ids')
    paddings = tf.reshape(paddings, [seq_len, -1], name='reshape_paddings')
    lm_output, state1.lm_states = self.lm.FProp(theta.lm, ids, paddings,
                                                state0.lm_states)

  if is_single_step:
    # lm outputs have dimension [time, batch, dim]. Since this is only one
    # step, remove time dimension.
    lm_output = lm_output.Transform(lambda v: tf.squeeze(v, axis=0))

  return lm_output, state1
def GetDecoderEmbeddingsDefaultTheta(self, input_ids, t=None):
  """Builds decoder input embeddings using the layer's own theta.

  Args:
    input_ids: Token ids; dimension 0 is used as the sequence length.
    t: Optional decoding step. When given, only the position embedding of step
      `t` is added instead of the embeddings for the whole sequence.

  Returns:
    Token + position embeddings after target dropout.
  """
  p = self.params
  seq_len = tf.shape(input_ids)[0]

  # [seq_len, batch, model_dim]
  token_embs = self.tgt_token_emb.EmbLookup(self.theta.tgt_token_emb,
                                            input_ids)

  if t is not None:
    # Support decoding: take row `t` out of the full position table.
    all_positions = self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb,
                                           p.max_seq_len)
    position_embs = tf.slice(all_positions, [t, 0],
                             [1, p.token_emb.embedding_dim])
  else:
    # [seq_len, 1, model_dim]; the size-1 axis broadcasts over the batch.
    position_embs = tf.expand_dims(
        self.tgt_pos_emb.FProp(self.theta.tgt_pos_emb, seq_len), 1)

  summed_embs = token_embs + position_embs
  return self.tgt_dropout.FProp(self.theta.tgt_dropout, summed_embs)
def FProp(self,
          theta,
          inputs,
          paddings,
          state0,
          labels=None,
          direct_features=None):
  """Computes xent loss given the language model input ids.

  Args:
    theta: A `.NestedMap` object containing weights' values of this layer and
      its children layers.
    inputs: input ids. An int32 tensor of shape [time, batch].
    paddings: a 0/1 tensor of shape [time, batch].
    state0: A `.NestedMap` containing the initial recurrent state.
    labels: If not None, a `.NestedMap` containing the following fields:

      - class_weights, a tensor with shape [time, batch] containing the
        weights for each target word.
      - class_ids, a tensor with shape [time, batch] of int32 dtype containing
        the target class labels.
      - class_probabilities, a tensor with shape [time, batch, vocab_size] of
        float values indicating class-membership probabilities.
    direct_features: If not None, a tensor of [time, batch,
      direct_feature_dims] that is concatenated to the output of the last RNN
      layer.

  Returns:
    If `labels` is not None, returns (xent_output, state1), where
    `xent_output` is a `.NestedMap` as defined by `SoftmaxLayer`'s return
    value and `state1` is the next recurrent state. Otherwise, `xent_output`
    only contains the softmax logits.
  """
  ids = py_utils.HasRank(inputs, 2)
  paddings = py_utils.HasShape(paddings, tf.shape(ids))
  assert state0

  embedded = self.emb.EmbLookup(theta.emb, ids)

  p = self.params
  # Dropout on embeddings is only applied in training.
  apply_dropout = p.embedding_dropout_keep_prob < 1.0 and not p.is_eval
  if apply_dropout:
    embedded = tf.nn.dropout(
        embedded,
        keep_prob=p.embedding_dropout_keep_prob,
        seed=p.embedding_dropout_seed)

  return super(RnnLm, self).FProp(theta, embedded, paddings, state0, labels,
                                  direct_features)
def GetEmbeddings(self, emb_theta, emb, pos_emb_theta, pos_emb, dropout_theta,
                  dropout, input_ids, input_pos_ids, task_emb_theta, task_emb,
                  task_ids):
  """Sums token, position and optional task embeddings, then applies dropout.

  Args:
    emb_theta: Weights for `emb`.
    emb: Token embedding layer; its `EmbLookup` is called on `input_ids`.
    pos_emb_theta: Weights for `pos_emb`.
    pos_emb: Position embedding layer.
    dropout_theta: Weights for `dropout`.
    dropout: Dropout layer applied to the summed embeddings.
    input_ids: Token ids. The time axis is 0 when `p.batch_dim` is non-zero and
      1 otherwise.
    input_pos_ids: Explicit positions, used only when `p.packed_input` is set.
    task_emb_theta: Weights for `task_emb`.
    task_emb: Task embedding layer; skipped when falsy.
    task_ids: Task ids looked up in `task_emb`.

  Returns:
    The embeddings after dropout.
  """
  p = self.params
  time_axis = 0 if p.batch_dim else 1
  seq_len = tf.shape(input_ids)[time_axis]

  token_embs = emb.EmbLookup(emb_theta, input_ids)

  if p.packed_input:
    # Packed inputs carry explicit positions.
    position_embs = pos_emb.FPropWithPosition(pos_emb_theta, input_pos_ids)
  else:
    # Add a size-1 batch axis so the positions broadcast over the batch.
    position_embs = tf.expand_dims(
        pos_emb.FProp(pos_emb_theta, seq_len), p.batch_dim)

  summed_embs = token_embs + position_embs
  if task_emb:
    summed_embs = summed_embs + task_emb.EmbLookup(task_emb_theta, task_ids)

  return dropout.FProp(dropout_theta, summed_embs)