def __call__(self, shape, dtype=None, partition_info=None):
  if dtype is None:
    dtype = self.dtype
  # Check the shape
  if len(shape) < 3 or len(shape) > 5:
    raise ValueError("The tensor to initialize must be at least "
                     "three-dimensional and at most five-dimensional")

  if shape[-2] > shape[-1]:
    raise ValueError("In_filters cannot be greater than out_filters.")

  # Generate a random matrix
  a = random_ops.random_normal([shape[-1], shape[-1]],
                               dtype=dtype, seed=self.seed)
  # Compute the qr factorization
  q, r = linalg_ops.qr(a, full_matrices=False)
  # Make Q uniform
  d = array_ops.diag_part(r)
  q *= math_ops.sign(d)
  q = q[:shape[-2], :]
  q *= math_ops.sqrt(math_ops.cast(self.gain, dtype=dtype))
  if len(shape) == 3:
    weight = array_ops.scatter_nd([[(shape[0] - 1) // 2]],
                                  array_ops.expand_dims(q, 0), shape)
  elif len(shape) == 4:
    weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, (shape[1] - 1) // 2]],
                                  array_ops.expand_dims(q, 0), shape)
  else:
    weight = array_ops.scatter_nd([[(shape[0] - 1) // 2, (shape[1] - 1) // 2,
                                    (shape[2] - 1) // 2]],
                                  array_ops.expand_dims(q, 0), shape)
  return weight
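A minimal sketch of the same idea using the public tf.* API and NumPy (not the initializer above, and the helper name delta_orthogonal_1d is made up): an orthogonal [in, out] block is scattered onto the spatial center of a [k, in, out] kernel, leaving every other tap zero.

import numpy as np
import tensorflow as tf

def delta_orthogonal_1d(k, in_filters, out_filters, gain=1.0, seed=None):
  # Hypothetical illustration of the delta-orthogonal construction above.
  rng = np.random.default_rng(seed)
  a = rng.standard_normal((out_filters, out_filters))
  q, r = np.linalg.qr(a)
  q *= np.sign(np.diag(r))                 # make the factorization unique
  q = q[:in_filters, :] * np.sqrt(gain)
  # A single [(k - 1) // 2] index selects the center tap along the spatial axis.
  return tf.scatter_nd([[(k - 1) // 2]], q[np.newaxis, ...],
                       [k, in_filters, out_filters])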
def testInvalidShape(self):
  # TODO(apassos) figure out how to unify these errors
  with self.assertRaises(errors.InvalidArgumentError
                         if context.executing_eagerly() else ValueError):
    array_ops.scatter_nd(indices=[0],  # this should be indices=[[0]]
                         updates=[0.0],
                         shape=[1])
def testEmptyOutputShape1(self):
  indices = array_ops.zeros([2, 2, 2], dtypes.int32)
  updates = array_ops.zeros([2, 2, 2], dtypes.int32)
  shape = constant_op.constant([0, 3, 2], dtypes.int32)

  with self.assertRaisesWithPredicateMatch(
      ValueError, "Indices and updates specified for empty output shape"):
    array_ops.scatter_nd(indices, updates, shape)
def _ctc_state_trans(label_seq):
  """Compute CTC alignment model transition matrix.

  Args:
    label_seq: tensor of shape [batch_size, max_seq_length]

  Returns:
    tensor of shape [batch_size, states, states] with a state transition matrix
    computed for each sequence of the batch.
  """
  with ops.name_scope("ctc_state_trans"):
    label_seq = ops.convert_to_tensor(label_seq, name="label_seq")
    batch_size = _get_dim(label_seq, 0)
    num_labels = _get_dim(label_seq, 1)

    num_label_states = num_labels + 1
    num_states = 2 * num_label_states

    label_states = math_ops.range(num_label_states)
    blank_states = label_states + num_label_states

    # Start state to first label.
    start_to_label = [[1, 0]]

    # Blank to label transitions.
    blank_to_label = array_ops.stack([label_states[1:], blank_states[:-1]], 1)

    # Label to blank transitions.
    label_to_blank = array_ops.stack([blank_states, label_states], 1)

    # Scatter transitions that don't depend on sequence.
    indices = array_ops.concat(
        [start_to_label, blank_to_label, label_to_blank], 0)
    values = array_ops.ones([_get_dim(indices, 0)])
    trans = array_ops.scatter_nd(
        indices, values, shape=[num_states, num_states])
    trans += linalg_ops.eye(num_states)  # Self-loops.

    # Label to label transitions. Disallow transitions between repeated labels
    # with no blank state in between.
    batch_idx = array_ops.zeros_like(label_states[2:])
    indices = array_ops.stack(
        [batch_idx, label_states[2:], label_states[1:-1]], 1)
    indices = array_ops.tile(
        array_ops.expand_dims(indices, 0), [batch_size, 1, 1])
    batch_idx = array_ops.expand_dims(math_ops.range(batch_size), 1) * [1, 0, 0]
    indices += array_ops.expand_dims(batch_idx, 1)

    repeats = math_ops.equal(label_seq[:, :-1], label_seq[:, 1:])
    values = 1.0 - math_ops.cast(repeats, dtypes.float32)
    batched_shape = [batch_size, num_states, num_states]
    label_to_label = array_ops.scatter_nd(indices, values, batched_shape)

    return array_ops.expand_dims(trans, 0) + label_to_label
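Hedged illustration of the scatter used above, with made-up index pairs rather than the real CTC state layout: (to_state, from_state) indices plus unit values produce an adjacency-style transition matrix, and eye() then adds the self-loops.

import tensorflow as tf

# Hypothetical 4-state example; the specific pairs below are for illustration only.
indices = tf.constant([[1, 0],   # e.g. start state -> first label
                       [2, 1],   # e.g. blank -> next label
                       [1, 1]])  # e.g. label self transition handled by eye() in the real code
values = tf.ones([3])
trans = tf.scatter_nd(indices, values, shape=[4, 4]) + tf.eye(4)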
def testEmptyOutputShape2(self):
  indices = array_ops.placeholder(dtypes.int32, shape=None)
  updates = array_ops.placeholder(dtypes.int32, shape=None)
  shape = constant_op.constant([0, 3, 2], dtypes.int32)

  with self.test_session():
    array_ops.scatter_nd(indices, updates, shape).eval(feed_dict={
        indices: np.zeros([2, 2, 2], dtype=np.int32),
        updates: np.zeros([2, 2, 2], dtype=np.int32)
    })
def testRank3InvalidShape2(self):
  indices = array_ops.zeros([2, 2, 1], dtypes.int32)
  updates = array_ops.zeros([2, 2], dtypes.int32)
  shape = np.array([2, 2, 2])
  with self.assertRaisesWithPredicateMatch(
      ValueError, "The inner \\d+ dimensions of output\\.shape="):
    array_ops.scatter_nd(indices, updates, shape)

  ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
  with self.assertRaisesWithPredicateMatch(
      ValueError, "The inner \\d+ dimensions of ref\\.shape="):
    state_ops.scatter_nd_update(ref, indices, updates)
def collapse_repeated(labels, seq_length, name=None):
  """Merge repeated labels into single labels.

  Args:
    labels: Tensor of shape [batch, max value in seq_length]
    seq_length: Tensor of shape [batch], sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    A tuple `(collapsed_labels, new_seq_length)` where

    collapsed_labels: Tensor of shape [batch, max_seq_length] with repeated
    labels collapsed and padded to max_seq_length, eg:
    `[[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]`

    new_seq_length: int tensor of shape [batch] with new sequence lengths.
  """

  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
    labels = ops.convert_to_tensor(labels, name="labels")
    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

    # Mask labels that don't equal previous label.
    label_mask = array_ops.concat([
        array_ops.ones_like(labels[:, :1], dtypes.bool),
        math_ops.not_equal(labels[:, 1:], labels[:, :-1])
    ], axis=1)

    # Filter labels that aren't in the original sequence.
    maxlen = _get_dim(labels, 1)
    seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
    label_mask = math_ops.logical_and(label_mask, seq_mask)

    # Count masks for new sequence lengths.
    new_seq_len = math_ops.reduce_sum(
        math_ops.cast(label_mask, dtypes.int32), axis=1)

    # Mask indexes based on sequence length mask.
    new_maxlen = math_ops.reduce_max(new_seq_len)
    idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)

    # Flatten everything and mask out labels to keep and sparse indices.
    flat_labels = array_ops.reshape(labels, [-1])
    flat_label_mask = array_ops.reshape(label_mask, [-1])
    flat_idx_mask = array_ops.reshape(idx_mask, [-1])
    idx = math_ops.range(_get_dim(flat_idx_mask, 0))

    # Scatter to flat shape.
    flat = array_ops.scatter_nd(
        indices=array_ops.expand_dims(
            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
        updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
        shape=array_ops.shape(flat_idx_mask))

    # Reshape back to square batch.
    batch_size = _get_dim(labels, 0)
    new_shape = [batch_size, new_maxlen]
    return (array_ops.reshape(flat, new_shape),
            math_ops.cast(new_seq_len, seq_length.dtype))
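A quick check of the collapse behavior described in the docstring, assuming eager execution and using the public wrapper (in newer TF releases this function is exposed as tf.nn.collapse_repeated):

import tensorflow as tf

labels = tf.constant([[1, 1, 2, 2, 1],
                      [1, 2, 3, 4, 5]])
seq_length = tf.constant([5, 5])
# With the inputs above this should yield
# [[1, 2, 1, 0, 0], [1, 2, 3, 4, 5]] and new lengths [3, 5].
collapsed, new_len = tf.nn.collapse_repeated(labels, seq_length)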
def testScatterNdRepatedIndicesAdd(self):
  indices = array_ops.zeros([100000, 1], dtypes.int32)
  values = np.random.randn(100000)
  shape = [1]
  with self.test_session():
    val = array_ops.scatter_nd(indices, values, shape).eval()
    self.assertAllClose([np.sum(values)], val)
def _state_to_olabel_unique(labels, num_labels, states, unique):
  """Sum state log probs to ilabel log probs using unique label indices."""

  num_label_states = _get_dim(labels, 1) + 1
  label_states = states[:, :, 1:num_label_states]
  blank_states = states[:, :, num_label_states:]

  unique_y, unique_idx = unique
  mul_reduce = _sum_states(unique_idx, label_states)

  num_frames = states.shape[0]
  batch_size = states.shape[1]
  num_states = num_label_states - 1
  batch_state_major = array_ops.transpose(mul_reduce, perm=[1, 2, 0])
  batch_state_major = array_ops.reshape(
      batch_state_major, [batch_size * num_states, num_frames])
  batch_offset = math_ops.range(batch_size, dtype=unique_y.dtype) * num_labels
  indices = unique_y + array_ops.expand_dims(batch_offset, axis=-1)
  indices = array_ops.reshape(indices, [-1, 1])
  scatter = array_ops.scatter_nd(
      indices=indices,
      updates=batch_state_major,
      shape=[batch_size * num_labels, num_frames])
  scatter = array_ops.reshape(scatter, [batch_size, num_labels, num_frames])

  scatter = array_ops.where(
      math_ops.equal(scatter, 0.0),
      array_ops.fill(array_ops.shape(scatter), math_ops.log(0.0)),
      scatter)

  label_olabels = array_ops.transpose(scatter, [2, 0, 1])
  label_olabels = label_olabels[:, :, 1:]

  blank_olabels = math_ops.reduce_logsumexp(
      blank_states, axis=2, keepdims=True)

  return array_ops.concat([blank_olabels, label_olabels], axis=-1)
def testEmptyOutputShape3(self):
  indices = array_ops.zeros([0], dtypes.int32)
  updates = array_ops.zeros([0], dtypes.int32)
  shape = constant_op.constant([0], dtypes.int32)
  scatter = array_ops.scatter_nd(indices, updates, shape)

  with self.test_session():
    self.assertEqual(scatter.eval().size, 0)
def maybe_sample():
  """Perform scheduled sampling."""
  where_sampling = math_ops.cast(
      array_ops.where(sample_ids > -1), dtypes.int32)
  where_not_sampling = math_ops.cast(
      array_ops.where(sample_ids <= -1), dtypes.int32)
  sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
  inputs_not_sampling = array_ops.gather_nd(
      base_next_inputs, where_not_sampling)
  sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
  base_shape = array_ops.shape(base_next_inputs)
  return (array_ops.scatter_nd(indices=where_sampling,
                               updates=sampled_next_inputs,
                               shape=base_shape)
          + array_ops.scatter_nd(indices=where_not_sampling,
                                 updates=inputs_not_sampling,
                                 shape=base_shape))
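The pattern above (and in several of the scheduled-sampling helpers below) is: split a batch by a boolean condition with where/gather_nd, transform one part, then recombine with two complementary scatter_nd calls. A standalone sketch with the public tf.* API, assuming a [batch] condition and [batch, d] inputs (the helper name recombine is made up):

import tensorflow as tf

def recombine(condition, transform_fn, base):
  """condition: [batch] bool; base: [batch, d]; transform_fn keeps row count."""
  where_true = tf.cast(tf.where(condition), tf.int32)
  where_false = tf.cast(tf.where(tf.logical_not(condition)), tf.int32)
  picked = transform_fn(tf.gather_nd(base, where_true))
  kept = tf.gather_nd(base, where_false)
  shape = tf.shape(base)
  # The two scatters write to disjoint rows, so their sum is the merged batch.
  return (tf.scatter_nd(where_true, picked, shape) +
          tf.scatter_nd(where_false, kept, shape))

# e.g. recombine(tf.constant([True, False, True]), lambda t: t * 10., tf.ones([3, 2]))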
def _apply_sparse_shared(self, grad_values, grad_indices, var):
  if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0:
    # The dimension is small enough, we can make the variable dense and
    # do a dense update
    dense_grad = array_ops.scatter_nd(
        array_ops.expand_dims(grad_indices, axis=1), grad_values,
        array_ops.shape(var, out_type=grad_indices.dtype))
    return self._apply_gradient(dense_grad, var)
  return self._apply_gradient(grad_values, var, grad_indices)
def _runScatterNd(self, indices, updates, shape):
  with self.test_session():
    updates_placeholder = array_ops.placeholder(updates.dtype)
    indices_placeholder = array_ops.placeholder(indices.dtype)
    with self.test_scope():
      output = array_ops.scatter_nd(indices_placeholder, updates_placeholder,
                                    shape)
    feed_dict = {updates_placeholder: updates, indices_placeholder: indices}
    return output.eval(feed_dict=feed_dict)
def _GatherNdGrad(op, grad):
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  if indices.shape.ndims == 2 and indices.shape.dims[-1].value == 1:
    ref_grad = ops.IndexedSlices(grad, array_ops.squeeze(indices, axis=-1),
                                 ref_shape)
  else:
    ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]
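Hedged note on why scatter_nd appears here: scattering is the adjoint of gathering, so routing the incoming gradient back to the gathered positions (zeros elsewhere) is exactly a scatter. A quick numeric check with the public eager API:

import tensorflow as tf

params = tf.constant([[1., 2.], [3., 4.], [5., 6.]])
indices = tf.constant([[2, 0], [0, 1]])   # gather two scalar elements
with tf.GradientTape() as tape:
  tape.watch(params)
  out = tf.gather_nd(params, indices)     # [5., 2.]
grad = tape.gradient(out, params)         # ones at (2, 0) and (0, 1), zeros elsewhere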
def maybe_sample():
  """Perform scheduled sampling."""
  if self._next_input_layer is None:
    return array_ops.where(sample_ids, outputs, base_next_inputs)

  where_sampling = math_ops.cast(
      array_ops.where(sample_ids), dtypes.int32)
  where_not_sampling = math_ops.cast(
      array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
  outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
  inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                            where_not_sampling)
  sampled_next_inputs = self._next_input_layer(outputs_sampling)
  base_shape = array_ops.shape(base_next_inputs)
  return (array_ops.scatter_nd(indices=where_sampling,
                               updates=sampled_next_inputs,
                               shape=base_shape)
          + array_ops.scatter_nd(indices=where_not_sampling,
                                 updates=inputs_not_sampling,
                                 shape=base_shape))
def testGradientsRank2SliceUpdate(self):
  indices = constant_op.constant([[1], [0]], dtype=dtypes.int32)
  updates = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
  shape = constant_op.constant([2, 2], dtype=dtypes.int32)
  outputs = array_ops.scatter_nd(indices, updates, shape)

  grad_vals = constant_op.constant([[3, 4], [1, 2]], dtype=dtypes.float64)
  grads = gradients_impl.gradients([outputs], [updates], [grad_vals])[0]
  expected_grads = np.array([[1, 2], [3, 4]], dtype=np.float64)
  with self.test_session():
    self.assertAllEqual(expected_grads, grads.eval())
def testRank3ValidShape(self):
  indices = array_ops.zeros([2, 2, 2], dtypes.int32)
  updates = array_ops.zeros([2, 2, 2], dtypes.int32)
  shape = np.array([2, 2, 2])
  self.assertAllEqual(
      array_ops.scatter_nd(indices, updates, shape).get_shape().as_list(),
      shape)

  ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
  self.assertAllEqual(
      state_ops.scatter_nd_update(ref, indices, updates).get_shape().as_list(),
      shape)
def scheduled_sampling(self, batch_size, sampling_probability, true, estimate):
  with variable_scope.variable_scope("ScheduledEmbedding"):
    # Return -1s where we do not sample, and sample_ids elsewhere
    select_sampler = bernoulli.Bernoulli(probs=sampling_probability,
                                         dtype=tf.bool)
    select_sample = select_sampler.sample(sample_shape=batch_size)
    sample_ids = array_ops.where(
        select_sample,
        tf.range(batch_size),
        gen_array_ops.fill([batch_size], -1))
    where_sampling = math_ops.cast(
        array_ops.where(sample_ids > -1), tf.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), tf.int32)
    _estimate = array_ops.gather_nd(estimate, where_sampling)
    _true = array_ops.gather_nd(true, where_not_sampling)

    base_shape = array_ops.shape(true)
    result1 = array_ops.scatter_nd(indices=where_sampling,
                                   updates=_estimate,
                                   shape=base_shape)
    result2 = array_ops.scatter_nd(indices=where_not_sampling,
                                   updates=_true,
                                   shape=base_shape)
    return result1 + result2
def maybe_sample():
  """Perform scheduled sampling."""

  def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
    """Concatenate outputs with auxiliary inputs, if they exist."""
    if self._auxiliary_input_tas is None:
      return outputs_

    next_time = time + 1
    auxiliary_inputs = nest.map_structure(
        lambda ta: ta.read(next_time), self._auxiliary_input_tas)
    if indices is not None:
      auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices)
    return nest.map_structure(
        lambda x, y: array_ops.concat((x, y), -1),
        outputs_, auxiliary_inputs)

  if self._next_input_layer is None:
    return array_ops.where(
        sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
        base_next_inputs)

  where_sampling = math_ops.cast(
      array_ops.where(sample_ids), dtypes.int32)
  where_not_sampling = math_ops.cast(
      array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
  outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
  inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                            where_not_sampling)
  sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
      self._next_input_layer(outputs_sampling), where_sampling)
  base_shape = array_ops.shape(base_next_inputs)
  return (array_ops.scatter_nd(indices=where_sampling,
                               updates=sampled_next_inputs,
                               shape=base_shape)
          + array_ops.scatter_nd(indices=where_not_sampling,
                                 updates=inputs_not_sampling,
                                 shape=base_shape))
def testExtraIndicesDimensions(self):
  indices = array_ops.zeros([1, 1, 2], dtypes.int32)
  updates = array_ops.zeros([1, 1], dtypes.int32)
  shape = np.array([2, 2])
  scatter = array_ops.scatter_nd(indices, updates, shape)
  self.assertAllEqual(scatter.get_shape().as_list(), shape)
  expected_result = np.zeros([2, 2], dtype=np.int32)
  with self.test_session():
    self.assertAllEqual(expected_result, scatter.eval())

  ref = variables.Variable(array_ops.zeros(shape, dtypes.int32))
  scatter_update = state_ops.scatter_nd_update(ref, indices, updates)
  self.assertAllEqual(scatter_update.get_shape().as_list(), shape)

  with self.test_session():
    ref.initializer.run()
    self.assertAllEqual(expected_result, scatter_update.eval())
def _TopKGrad(op, grad, _):
  """Return the gradients for TopK.

  Args:
    op: The TopKOp for which we need to generate gradients.
    grad: Tensor. The gradients passed to the TopKOp.

  Returns:
    A list of two tensors, the first being the gradient w.r.t. the input of
    TopK, and the second being the gradient w.r.t. the indices (all zero).
  """
  in_shape = array_ops.shape(op.inputs[0])
  ind_shape = array_ops.shape(op.outputs[1])

  # int32 is not supported on GPU hence up-casting
  ind_lastdim = array_ops.gather(
      math_ops.cast(ind_shape, dtypes.int64),
      array_ops.size(ind_shape) - 1)
  # Flatten indices to 2D.
  ind_2d = array_ops.reshape(op.outputs[1], array_ops.stack([-1, ind_lastdim]))

  in_lastdim = array_ops.gather(
      math_ops.cast(in_shape, dtypes.int64),
      array_ops.size(in_shape) - 1)
  outerdim = array_ops.shape(ind_2d)[0]
  # Compute linear indices (flattened to 1D).
  ind = array_ops.reshape(
      ind_2d + math_ops.cast(
          array_ops.expand_dims(
              math_ops.range(0,
                             math_ops.cast(outerdim, dtypes.int64) * in_lastdim,
                             in_lastdim), -1), dtypes.int32), [-1])

  # Substitute grad to appropriate locations and fill the rest with zeros,
  # finally reshaping it to the original input shape.
  return [
      array_ops.reshape(
          array_ops.scatter_nd(
              array_ops.expand_dims(ind, -1), array_ops.reshape(grad, [-1]),
              [math_ops.reduce_prod(in_shape)]), in_shape),
      array_ops.zeros([], dtype=dtypes.int32)
  ]
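A sketch of the index arithmetic above, assuming a small 2-D input for illustration: top_k indices are per-row column positions, so adding row * last_dim turns them into linear indices into the flattened input before the 1-D scatter.

import tensorflow as tf

x = tf.constant([[0.1, 0.9, 0.4],
                 [0.8, 0.2, 0.7]])
values, idx = tf.math.top_k(x, k=2)                  # idx: per-row column positions
row_offsets = tf.range(tf.shape(x)[0]) * tf.shape(x)[1]
linear = tf.reshape(idx + row_offsets[:, None], [-1, 1])
flat = tf.scatter_nd(linear, tf.reshape(values, [-1]),
                     tf.shape(tf.reshape(x, [-1])))
back = tf.reshape(flat, tf.shape(x))                 # top-2 values in place, zeros elsewhere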
def ScatterUpdateGrads(op, grad):
  var, indices, updates = op.inputs

  updates_grad = array_ops.gather(grad, indices)

  # dynamic stitch approach (this seems to be a bit slower)
  # grad_range = math_ops.range(grad.get_shape()[0].value)
  # var_grad = data_flow_ops.dynamic_stitch(
  #     [grad_range, indices],
  #     [grad, array_ops.zeros(updates.get_shape())])

  if isinstance(grad, ops.IndexedSlices):
    # note: we could use this approach for everything, but the
    # temporary variable approach seems to be slightly faster (but we
    # can't use that on indexedslices)
    var_grad = grad - array_ops.scatter_nd(
        array_ops.expand_dims(indices, 1), updates_grad, var.get_shape())
  else:
    shape = tuple(grad.get_shape().as_list())
    dtype = grad.dtype.base_dtype

    with variable_scope.variable_scope(
        "gradient_vars", reuse=variable_scope.AUTO_REUSE):
      var_grad = variable_scope.get_variable(
          "tmp" + "_%s" * (len(grad.get_shape()) + 1) % (
              shape + (dtype.name,)),
          shape=shape, dtype=dtype, trainable=False,
          collections=["gradient_vars"])

    var_grad = state_ops.assign(var_grad, grad)
    var_grad = state_ops.scatter_update(
        var_grad, indices, array_ops.zeros_like(updates))

    # we need to force a copy so that any future assignments to the
    # variable will not affect the value we return here
    # TODO: check if this is still necessary in TensorFlow 2.0
    var_grad = var_grad + 0

  return var_grad, None, updates_grad
def testUndefinedOutputShape(self):
  indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
  updates = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
  shape = array_ops.placeholder(dtypes.int32, shape=[None])
  array_ops.scatter_nd(indices, updates, shape)
def _GatherNdGrad(op, grad):
  ref = op.inputs[0]
  indices = op.inputs[1]
  ref_shape = array_ops.shape(ref, out_type=indices.dtype)
  ref_grad = array_ops.scatter_nd(indices, grad, ref_shape)
  return [ref_grad, None]
def scheduled_sampling_vocab_dist(hps, sampling_probability, output, embedding,
                                  inp, alpha=0):
  # borrowed ideas from
  # https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper

  def soft_argmax(alpha, output):
    # alpha_exp = tf.exp(alpha * output)  # (batch_size, vocab_size)
    # one_hot_scores = alpha_exp / tf.reshape(
    #     tf.reduce_sum(alpha_exp, axis=1), [-1, 1])  # (batch_size, vocab_size)
    one_hot_scores = tf.nn.softmax(alpha * output)
    return one_hot_scores

  def soft_top_k(alpha, output, K):
    copy = tf.identity(output)
    p = []
    arg_top_k = []
    for k in range(K):
      sargmax = soft_argmax(alpha, copy)
      copy = (1 - sargmax) * copy
      p.append(tf.reduce_sum(sargmax * output, axis=1))
      arg_top_k.append(sargmax)

    return tf.stack(p, axis=1), tf.stack(arg_top_k)

  with variable_scope.variable_scope("ScheduledEmbedding"):
    # Return -1s where we did not sample, and sample_ids elsewhere
    select_sampler = bernoulli.Bernoulli(probs=sampling_probability,
                                         dtype=tf.bool)
    select_sample = select_sampler.sample(sample_shape=hps.batch_size)
    # equals to argmax{ Multinomial(output, total_count=1) }, our greedy
    # search selection
    sample_id_sampler = categorical.Categorical(probs=output)
    sample_ids = array_ops.where(select_sample,
                                 sample_id_sampler.sample(seed=123),
                                 gen_array_ops.fill([hps.batch_size], -1))

    where_sampling = math_ops.cast(array_ops.where(sample_ids > -1), tf.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), tf.int32)

    if hps.greedy_scheduled_sampling:
      sample_ids = tf.argmax(output, axis=1, output_type=tf.int32)

    sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
    inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)

    if hps.E2EBackProp:
      if hps.hard_argmax:
        greedy_search_prob, greedy_search_sample = tf.nn.top_k(
            output, k=hps.k)  # (batch_size, k)
        greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
            tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])
        greedy_embedding = tf.nn.embedding_lookup(embedding,
                                                  greedy_search_sample)
        normalized_embedding = tf.multiply(
            tf.reshape(greedy_search_prob_normalized,
                       [hps.batch_size, hps.k, 1]), greedy_embedding)
        e2e_embedding = tf.reduce_mean(normalized_embedding, axis=1)
      else:
        e = []
        greedy_search_prob, greedy_search_sample = soft_top_k(
            alpha, output,
            K=hps.k)  # (batch_size, k), (k, batch_size, vocab_size)
        greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
            tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])

        for _ in range(hps.k):
          a_k = greedy_search_sample[_]
          e_k = tf.matmul(
              tf.reshape(greedy_search_prob_normalized[:, _], [-1, 1]) * a_k,
              embedding)
          e.append(e_k)
        e2e_embedding = tf.reduce_sum(e, axis=0)  # (batch_size, emb_dim)
      sampled_next_inputs = array_ops.gather_nd(e2e_embedding, where_sampling)
    else:
      if hps.hard_argmax:
        sampled_next_inputs = tf.nn.embedding_lookup(embedding,
                                                     sample_ids_sampling)
      else:
        # using soft argmax (greedy) proposed in:
        # https://arxiv.org/abs/1704.06970
        # alpha_exp = tf.exp(alpha * (output_not_extended + G))  # (batch_size, vocab_size)
        # one_hot_scores = alpha_exp / tf.reduce_sum(alpha_exp, axis=1)  # (batch_size, vocab_size)
        one_hot_scores = soft_argmax(alpha, output)  # (batch_size, vocab_size)
        soft_argmax_embedding = tf.matmul(one_hot_scores,
                                          embedding)  # (batch_size, emb_size)
        sampled_next_inputs = array_ops.gather_nd(soft_argmax_embedding,
                                                  where_sampling)

    base_shape = array_ops.shape(inp)
    result1 = array_ops.scatter_nd(indices=where_sampling,
                                   updates=sampled_next_inputs,
                                   shape=base_shape)
    result2 = array_ops.scatter_nd(indices=where_not_sampling,
                                   updates=inputs_not_sampling,
                                   shape=base_shape)
    return result1 + result2
def scatter_nd(self, indices, updates, shape, input_=None):
  del input_  # input_ is not used in scatter_nd
  return array_ops.scatter_nd(indices, updates, shape)
def _get_update_op(self, score_drop, score_grow, mask, weights,
                   reinit_when_same=False):
  """Prunes+grows connections, all tensors same shape."""
  old_dtype = mask.dtype
  mask_casted = math_ops.cast(mask, dtypes.float32)
  n_total = array_ops.size(score_drop)
  n_ones = math_ops.cast(math_ops.reduce_sum(mask_casted), dtype=dtypes.int32)
  n_prune = math_ops.cast(
      math_ops.cast(n_ones, dtype=dtypes.float32) * self.drop_fraction,
      dtypes.int32)
  n_keep = n_ones - n_prune

  # Sort the entire array since the k needs to be constant for TPU.
  _, sorted_indices = nn_ops.top_k(
      array_ops.reshape(score_drop, [-1]), k=n_total)
  sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
  # We will have zeros after having `n_keep` many ones.
  new_values = array_ops.where(
      math_ops.range(n_total) < n_keep,
      array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),
      array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))
  mask1 = array_ops.scatter_nd(sorted_indices_ex, new_values,
                               new_values.shape)

  # Flatten the scores
  score_grow = array_ops.reshape(score_grow, [-1])
  # Set scores of the enabled connections(ones) to min(s) - 1, so that they
  # have the lowest scores.
  score_grow_lifted = array_ops.where(
      math_ops.equal(mask1, 1),
      array_ops.ones_like(mask1) * (math_ops.reduce_min(score_grow) - 1),
      score_grow)
  _, sorted_indices = nn_ops.top_k(score_grow_lifted, k=n_total)
  sorted_indices_ex = array_ops.expand_dims(sorted_indices, 1)
  new_values = array_ops.where(
      math_ops.range(n_total) < n_prune,
      array_ops.ones_like(sorted_indices, dtype=mask_casted.dtype),
      array_ops.zeros_like(sorted_indices, dtype=mask_casted.dtype))
  mask2 = array_ops.scatter_nd(sorted_indices_ex, new_values,
                               new_values.shape)

  # Ensure masks are disjoint.
  assert_op = control_flow_ops.Assert(
      math_ops.equal(math_ops.reduce_sum(mask1 * mask2), 0.), [mask1, mask2])

  with ops.control_dependencies([assert_op]):
    # Let's set the weights of the growed connections.
    mask2_reshaped = array_ops.reshape(mask2, mask.shape)
    # Set the values of the new connections.
    grow_tensor = self.get_grow_tensor(weights, self._grow_init)
    if reinit_when_same:
      # If dropped and grown, we re-initialize.
      new_connections = math_ops.equal(mask2_reshaped, 1)
    else:
      new_connections = math_ops.logical_and(
          math_ops.equal(mask2_reshaped, 1), math_ops.equal(mask_casted, 0))
    new_weights = array_ops.where(new_connections, grow_tensor, weights)
    weights_update = state_ops.assign(weights, new_weights)
    # Ensure there is no momentum value for new connections
    reset_op = self.reset_momentum(weights, new_connections)

    with ops.control_dependencies([weights_update, reset_op]):
      mask_combined = array_ops.reshape(mask1 + mask2, mask.shape)
      mask_combined = math_ops.cast(mask_combined, dtype=old_dtype)
      new_mask = state_ops.assign(mask, mask_combined)
  return new_mask
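The mask trick above in isolation, as a small sketch with the public tf.* API: scatter ones at the indices of the top n_keep scores to get a {0, 1} mask the same size as the flat score vector.

import tensorflow as tf

scores = tf.constant([0.3, 0.9, 0.1, 0.7])
n_keep = 2
_, order = tf.math.top_k(scores, k=tf.size(scores))          # full descending sort
keep = tf.cast(tf.range(tf.size(scores)) < n_keep, tf.float32)
mask = tf.scatter_nd(order[:, None], keep, tf.shape(scores))
# mask == [0., 1., 0., 1.] for the example scores above.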
def scheduled_sampling(hps, sampling_probability, output, embedding, inp,
                       alpha=0):
  # borrowed ideas from
  # https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper
  vocab_size = embedding.get_shape()[0].value

  def soft_argmax(alpha, output):
    alpha_exp = tf.exp(alpha * output)  # (batch_size, vocab_size)
    one_hot_scores = alpha_exp / tf.reshape(
        tf.reduce_sum(alpha_exp, axis=1), [-1, 1])  # (batch_size, vocab_size)
    return one_hot_scores

  def soft_top_k(alpha, output, K):
    copy = tf.identity(output)
    p = []
    arg_top_k = []
    for k in range(K):
      sargmax = soft_argmax(alpha, copy)
      copy = (1 - sargmax) * copy
      p.append(tf.reduce_sum(sargmax * output, axis=1))
      # replace oov with unk if necessary
      mask = tf.equal(
          tf.reduce_max(sargmax, axis=1),
          tf.reduce_max(sargmax[:, 0:vocab_size], axis=1))
      sargmax_truncated = tf.where(
          mask, sargmax[:, 0:vocab_size],
          tf.stack([tf.one_hot(0, vocab_size)
                    for _ in range(hps.batch_size)]))

      arg_top_k.append(sargmax_truncated)

    return p, tf.stack(arg_top_k)

  with variable_scope.variable_scope("ScheduledEmbedding"):
    # Return -1s where we did not sample, and sample_ids elsewhere
    select_sampler = bernoulli.Bernoulli(probs=sampling_probability,
                                         dtype=tf.bool)
    select_sample = select_sampler.sample(sample_shape=hps.batch_size)
    # equals to argmax{ Multinomial(output, total_count=1) }, our greedy
    # search selection
    sample_id_sampler = categorical.Categorical(probs=output)
    sample_ids = array_ops.where(select_sample,
                                 sample_id_sampler.sample(seed=123),
                                 gen_array_ops.fill([hps.batch_size], -1))

    where_sampling = math_ops.cast(array_ops.where(sample_ids > -1), tf.int32)
    where_not_sampling = math_ops.cast(
        array_ops.where(sample_ids <= -1), tf.int32)

    if hps.greedy_scheduled_sampling:
      sample_ids = tf.argmax(output, axis=1, output_type=tf.int32)

    sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
    cond = tf.less(sample_ids_sampling, vocab_size)  # replace oov with unk
    sample_ids_sampling = tf.cast(cond, tf.int32) * sample_ids_sampling
    inputs_not_sampling = array_ops.gather_nd(inp, where_not_sampling)

    if hps.E2EBackProp:
      if hps.hard_argmax:
        greedy_search_prob, greedy_search_sample = tf.nn.top_k(
            output, k=hps.k)  # (batch_size, k)
        greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
            tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])
        cond = tf.less(greedy_search_sample,
                       vocab_size)  # replace oov with unk
        greedy_search_sample = tf.cast(cond, tf.int32) * greedy_search_sample
        greedy_embedding = tf.nn.embedding_lookup(embedding,
                                                  greedy_search_sample)
        normalized_embedding = tf.multiply(
            tf.reshape(greedy_search_prob_normalized,
                       [hps.batch_size, hps.k, 1]), greedy_embedding)
        e2e_embedding = tf.reduce_mean(normalized_embedding, axis=1)
      else:
        e = []
        greedy_search_prob, greedy_search_sample = soft_top_k(
            alpha, output, K=hps.k)  # (batch_size, k), (k, vocab_size)
        greedy_search_prob_normalized = greedy_search_prob / tf.reshape(
            tf.reduce_sum(greedy_search_prob, axis=1), [-1, 1])

        for _ in range(hps.k):
          a_k = greedy_search_sample[_]
          e_k = tf.matmul(
              tf.reshape(greedy_search_prob_normalized[:, _], [-1, 1]) * a_k,
              embedding)
          e.append(e_k)
        e2e_embedding = tf.reduce_sum(e, axis=0)  # (batch_size, emb_dim)
      sampled_next_inputs = array_ops.gather_nd(e2e_embedding, where_sampling)
    else:
      if hps.hard_argmax:
        sampled_next_inputs = tf.nn.embedding_lookup(embedding,
                                                     sample_ids_sampling)
      else:
        # using soft argmax (greedy) proposed in:
        # https://arxiv.org/abs/1704.06970
        if not hps.greedy_scheduled_sampling:
          # Gumbel reparametrization trick: https://arxiv.org/abs/1704.06970
          U = tf.random_uniform((hps.batch_size, vocab_size), 10e-12,
                                (1 - 10e-12))  # add a small number to avoid log(0)
          G = -tf.log(-tf.log(U))
        else:
          G = tf.zeros((hps.batch_size, vocab_size))
        # alpha_exp = tf.exp(alpha * (output_not_extended + G))  # (batch_size, vocab_size)
        # one_hot_scores = alpha_exp / tf.reduce_sum(alpha_exp, axis=1)  # (batch_size, vocab_size)
        one_hot_scores = soft_argmax(alpha,
                                     (output + G))  # (batch_size, vocab_size)
        sampled_next_inputs = tf.matmul(one_hot_scores,
                                        embedding)  # (batch_size, emb_size)

    base_shape = array_ops.shape(inp)
    result1 = array_ops.scatter_nd(indices=where_sampling,
                                   updates=sampled_next_inputs,
                                   shape=base_shape)
    result2 = array_ops.scatter_nd(indices=where_not_sampling,
                                   updates=inputs_not_sampling,
                                   shape=base_shape)
    return result1 + result2
def _groupwise_dnn_v2(features, labels, mode, params, config):
  """Defines the dnn for groupwise scoring functions."""
  with ops.name_scope('transform'):
    context_features, per_example_features = _call_transform_fn(
        features, mode)

  def _score_fn(context_features, group_features, reuse):
    with variable_scope.variable_scope('group_score', reuse=reuse):
      return group_score_fn(context_features, group_features, mode, params,
                            config)

  # Scatter/Gather per-example scores through groupwise comparison. Each
  # instance in a mini-batch will form a number of groups. Each group of
  # examples is scored by `score_fn` and scores for individual examples are
  # accumulated over groups.
  with ops.name_scope('groupwise_dnn_v2'):
    with ops.name_scope('infer_sizes'):
      if labels is not None:
        batch_size, list_size = array_ops.unstack(array_ops.shape(labels))
        is_valid = utils.is_label_valid(labels)
      else:
        # Infer batch_size and list_size from a feature.
        example_tensor_shape = array_ops.shape(
            next(six.itervalues(per_example_features)))
        batch_size = example_tensor_shape[0]
        list_size = example_tensor_shape[1]
        is_valid = utils.is_label_valid(
            array_ops.ones([batch_size, list_size]))
    if batch_size is None or list_size is None:
      raise ValueError('Invalid batch_size=%s or list_size=%s' %
                       (batch_size, list_size))

    # For each example feature, assume the shape is [batch_size, list_size,
    # feature_size], the groups are formed along the 2nd dim. Each group has a
    # 'group_size' number of indices in [0, list_size). Based on these
    # indices, we can gather the example feature into a sub-tensor for each
    # group. The total number of groups we have for a mini-batch is batch_size
    # * num_groups. Inside each group, we have a 'group_size' number of
    # examples.
    indices, mask = _form_group_indices_nd(
        is_valid, group_size, shuffle=(mode != model_fn.ModeKeys.PREDICT))
    num_groups = array_ops.shape(mask)[1]

    with ops.name_scope('group_features'):
      # For context features, we have shape [batch_size * num_groups, ...].
      large_batch_context_features = {}
      for name, value in six.iteritems(context_features):
        # [batch_size, 1, ...].
        value = array_ops.expand_dims(value, axis=1)
        # [batch_size, num_groups, ...].
        value = array_ops.gather(
            value, array_ops.zeros([num_groups], dtypes.int32), axis=1)
        # [batch_size * num_groups, ...]
        large_batch_context_features[name] = utils.reshape_first_ndims(
            value, 2, [batch_size * num_groups])

      # For example feature, we have shape [batch_size * num_groups,
      # group_size, ...].
      large_batch_group_features = {}
      for name, value in six.iteritems(per_example_features):
        # [batch_size, num_groups, group_size, ...].
        value = array_ops.gather_nd(value, indices)
        # [batch_size * num_groups, group_size, ...].
        large_batch_group_features[name] = utils.reshape_first_ndims(
            value, 3, [batch_size * num_groups, group_size])

    # Do the inference and get scores for the large batch.
    # [batch_size * num_groups, group_size].
    scores = _score_fn(
        large_batch_context_features, large_batch_group_features, reuse=False)

    with ops.name_scope('accumulate_scores'):
      scores = array_ops.reshape(scores, [batch_size, num_groups, group_size])
      # Reset invalid scores to 0 based on mask.
      scores = array_ops.where(
          array_ops.gather(
              array_ops.expand_dims(mask, 2),
              array_ops.zeros([group_size], dtypes.int32),
              axis=2), scores, array_ops.zeros_like(scores))
      # [batch_size, num_groups, group_size].
      list_scores = array_ops.scatter_nd(indices, scores,
                                         [batch_size, list_size])
      # Use average.
      list_scores /= math_ops.to_float(group_size)

  if mode == model_fn.ModeKeys.PREDICT:
    return list_scores
  else:
    features.update(context_features)
    features.update(per_example_features)
    return list_scores
def _replicate_rows(tensor, multiple):
  tensor_shape = tensor.shape.as_list()
  expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]]
  indices = constant_op.constant(_generate_indices(tensor_shape[0], multiple))
  return array_ops.scatter_nd(indices, _tile_rows(tensor, multiple),
                              expanded_shape)
def testUndefinedUpdatesShape(self):
  indices = array_ops.placeholder(dtypes.int32, shape=[2, 2, 2])
  updates = array_ops.placeholder(dtypes.int32, shape=None)
  shape = constant_op.constant([2, 2, 2], dtypes.int32)
  array_ops.scatter_nd(indices, updates, shape)
def collapse_repeated(labels, seq_length, name=None):
  """Merge repeated labels into single labels.

  Args:
    labels: Tensor of shape (batch, max value in seq_length)
    seq_length: Tensor of shape (batch), sequence length of each batch element.
    name: A name for this `Op`. Defaults to "collapse_repeated_labels".

  Returns:
    tuple of Tensor of shape (batch, max_seq_length) with repeated labels
    collapsed and padded to max_seq_length, eg:
    [[A, A, B, B, A], [A, B, C, D, E]] => [[A, B, A, 0, 0], [A, B, C, D, E]]
    and int tensor of shape [batch] with new sequence lengths.
  """

  with ops.name_scope(name, "collapse_repeated_labels", [labels, seq_length]):
    labels = ops.convert_to_tensor(labels, name="labels")
    seq_length = ops.convert_to_tensor(seq_length, name="seq_length")

    # Mask labels that don't equal previous label.
    label_mask = array_ops.concat([
        array_ops.ones_like(labels[:, :1], dtypes.bool),
        math_ops.not_equal(labels[:, 1:], labels[:, :-1])
    ], axis=1)

    # Filter labels that aren't in the original sequence.
    maxlen = _get_dim(labels, 1)
    seq_mask = array_ops.sequence_mask(seq_length, maxlen=maxlen)
    label_mask = math_ops.logical_and(label_mask, seq_mask)

    # Count masks for new sequence lengths.
    new_seq_len = math_ops.reduce_sum(
        math_ops.cast(label_mask, dtypes.int32), axis=1)

    # Mask indexes based on sequence length mask.
    new_maxlen = math_ops.reduce_max(new_seq_len)
    idx_mask = array_ops.sequence_mask(new_seq_len, maxlen=new_maxlen)

    # Flatten everything and mask out labels to keep and sparse indices.
    flat_labels = array_ops.reshape(labels, [-1])
    flat_label_mask = array_ops.reshape(label_mask, [-1])
    flat_idx_mask = array_ops.reshape(idx_mask, [-1])
    idx = math_ops.range(_get_dim(flat_idx_mask, 0))

    # Scatter to flat shape.
    flat = array_ops.scatter_nd(
        indices=array_ops.expand_dims(
            array_ops.boolean_mask(idx, flat_idx_mask), axis=1),
        updates=array_ops.boolean_mask(flat_labels, flat_label_mask),
        shape=array_ops.shape(flat_idx_mask))

    # Reshape back to square batch.
    batch_size = _get_dim(labels, 0)
    new_shape = [batch_size, new_maxlen]
    return (array_ops.reshape(flat, new_shape),
            math_ops.cast(new_seq_len, seq_length.dtype))