def testReset(self):
  batch_size = 2
  key_depth = 3
  val_depth = 5
  memory_size = 4
  memory = transformer_memory.TransformerMemory(
      batch_size, key_depth, val_depth, memory_size)
  vals = tf.random_uniform([batch_size, memory_size, val_depth], minval=1.0)
  logits = tf.random_uniform([batch_size, memory_size], minval=1.0)
  update_op = memory.set(vals, logits)
  reset_op = memory.reset([1])
  mem_vals, mem_logits = memory.get()
  assert_op1 = tf.assert_equal(mem_vals[0], vals[0])
  assert_op2 = tf.assert_equal(mem_logits[0], logits[0])
  with tf.control_dependencies([assert_op1, assert_op2]):
    all_zero1 = tf.reduce_sum(tf.abs(mem_vals[1]))
    all_zero2 = tf.reduce_sum(tf.abs(mem_logits[1]))
  with self.test_session() as session:
    session.run(tf.global_variables_initializer())
    session.run(update_op)
    session.run(reset_op)
    zero1, zero2 = session.run([all_zero1, all_zero2])
  self.assertAllEqual(0, zero1)
  self.assertAllEqual(0, zero2)
def area_range_to_index(area_range, length, max_area_width):
  """Computes the indices of each area in the area expansion.

  Args:
    area_range: tensor in shape of [batch_size, 2]
    length: a scalar tensor giving the length of the original feature space.
    max_area_width: a constant scalar.
  Returns:
    indices: area indices tensor in shape of [batch_size]
  """
  with tf.control_dependencies([tf.assert_equal(tf.rank(area_range), 2),
                                tf.assert_equal(tf.shape(area_range)[1], 2)]):
    area_range = tf.cast(area_range, tf.int32)
  target_size = area_range[:, 1] - area_range[:, 0]
  with tf.control_dependencies([
      tf.assert_less(target_size, max_area_width + 1, summarize=100000)]):
    sizes = target_size - 1
    start_length = length
    pre_end_length = length - sizes + 1
    base = (start_length + pre_end_length) * (
        start_length - pre_end_length + 1) // 2
    base = tf.where(tf.less_equal(target_size, 1),
                    tf.zeros_like(target_size), base)
    offset = area_range[:, 0]
    return base + offset
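# Worked example (added for illustration, not from the source): for a feature
# space of length L, areas are enumerated width-1 first (L of them), then
# width-2 (L - 1 of them), and so on, so a width-s area starting at `offset`
# gets index base(s) + offset with base(s) = (2L - s + 2) * (s - 1) // 2 --
# the same closed form computed above. A NumPy check of that layout:
import numpy as np

def area_index_reference(start, end, length):
  """Brute force: count all areas narrower than this one, then add the start."""
  width = end - start
  index = sum(length - w + 1 for w in range(1, width))
  return index + start

length = 4
for start, end in [(0, 1), (3, 4), (1, 3), (0, 3)]:
  s = end - start
  base = 0 if s <= 1 else (2 * length - s + 2) * (s - 1) // 2
  assert base + start == area_index_reference(start, end, length)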
def batch_gather(values, indices):
  """Gather slices from values.

  Args:
    values: a tensor in the shape of [batch_size, length, depth].
    indices: a tensor in the shape of [batch_size, slice_count] where
        slice_count < length.
  Returns:
    a tensor in the shape of [batch_size, slice_count, depth].
  """
  with tf.control_dependencies([
      tf.assert_equal(tf.rank(values), 3, message="values"),
      tf.assert_equal(tf.rank(indices), 2, message="indices"),
      tf.assert_equal(tf.shape(values)[0], tf.shape(indices)[0],
                      message="batch"),
  ]):
    shape = common_layers.shape_list(indices)
    depth = common_layers.shape_list(values)[-1]
    batch_indices = tf.reshape(
        tf.tile(tf.expand_dims(tf.range(shape[0]), [1]), [1, shape[1]]),
        [-1, 1])
    indices = tf.concat(
        [batch_indices,
         tf.cast(tf.reshape(indices, [-1, 1]), tf.int32)], axis=-1)
    slices = tf.gather_nd(params=values, indices=indices)
    return tf.reshape(slices, [shape[0], shape[1], depth])
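# Added for illustration (not from the source): batch_gather selects one slice
# per index along the length axis, independently per batch element. The same
# result in NumPy with toy shapes:
import numpy as np

values = np.arange(2 * 4 * 3).reshape(2, 4, 3)  # [batch=2, length=4, depth=3]
indices = np.array([[0, 2], [1, 3]])            # [batch=2, slice_count=2]
gathered = np.take_along_axis(values, indices[:, :, None], axis=1)
assert gathered.shape == (2, 2, 3)              # [batch, slice_count, depth]
assert (gathered[1, 0] == values[1, 1]).all()   # slice 1 of batch element 1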
def convert_padding_mask_to_attention_mask(sequence, padding_mask):
  """Given a padded input tensor of sequences and a boolean mask for each
  position in the sequence, returns a 3D boolean mask for use in attention.

  Args:
    sequence (tf.Tensor): Tensor of shape
        [batch_size, sequence_length_1, ndim]
    padding_mask (tf.Tensor[bool]): Tensor of shape
        [batch_size, sequence_length_2]

  Returns:
    tf.Tensor[bool]: Tensor of shape
        [batch_size, sequence_length_1, sequence_length_2]
  """
  batch_assert = tf.assert_equal(
      tf.shape(padding_mask)[0], tf.shape(sequence)[0],
      message='batch size mismatch between input sequence and padding_mask')
  rank_assert = tf.assert_equal(
      tf.rank(padding_mask), 2,
      message='Can only convert 2D position mask to 3D attention mask')
  with tf.control_dependencies([batch_assert, rank_assert]):
    attention_mask = tf.tile(padding_mask[:, None, :],
                             (1, tf.shape(sequence)[1], 1))
  return attention_mask
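# Added for illustration (not from the source): the tile above is plain
# broadcasting -- every query position reuses the same key-side padding row.
import numpy as np

padding_mask = np.array([[True, True, False]])  # [batch=1, sequence_length_2=3]
sequence_length_1 = 2
attention_mask = np.broadcast_to(
    padding_mask[:, None, :], (1, sequence_length_1, 3))
assert attention_mask.shape == (1, 2, 3)
assert (attention_mask[0, 0] == attention_mask[0, 1]).all()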
def convert_sequence_length_to_sequence_mask(sequence, sequence_lengths):
  """Given a padded input tensor of sequences and a tensor of lengths,
  returns a boolean mask for each position in the sequence indicating
  whether or not that position is padding.

  Args:
    sequence (tf.Tensor): Tensor of shape [batch_size, sequence_length, ndim]
    sequence_lengths (tf.Tensor[int]): Tensor of shape [batch_size]

  Returns:
    tf.Tensor[bool]: Tensor of shape [batch_size, sequence_length]
  """
  batch_assert = tf.assert_equal(
      tf.shape(sequence_lengths)[0], tf.shape(sequence)[0],
      message='batch size mismatch between input sequence and '
              'sequence_lengths')
  rank_assert = tf.assert_equal(
      tf.rank(sequence_lengths), 1,
      message='Can only convert 1D sequence_lengths to 2D mask')
  with tf.control_dependencies([batch_assert, rank_assert]):
    indices = tf.tile(
        tf.range(tf.shape(sequence)[1])[None, :],
        (tf.shape(sequence_lengths)[0], 1))
    mask = indices < sequence_lengths[:, None]
  return mask
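# Added for illustration (not from the source): the comparison is equivalent
# to tf.sequence_mask -- position i is kept iff i < length for that element.
import numpy as np

sequence_lengths = np.array([2, 4])  # [batch=2]
indices = np.arange(4)[None, :]      # [1, sequence_length=4]
mask = indices < sequence_lengths[:, None]
assert (mask == [[True, True, False, False],
                 [True, True, True, True]]).all()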
def call(self, inputs):
  """Returns action distribution, given a state."""
  act_mu = self.mu(inputs)
  act_sig = tf.exp(tf.tile(self.logsig, [tf.shape(act_mu)[0], 1]))
  # Note: outside eager mode this assert op is created but never run, since
  # its result is not wired into the graph via control_dependencies.
  tf.assert_equal(act_mu.shape, act_sig.shape)
  act_dist = self.dist(act_mu, act_sig)
  return act_dist
def predict_refs(logits, starts, ends):
  """Outputs the refs based on area predictions."""
  with tf.control_dependencies([
      tf.assert_equal(tf.rank(logits), 3),
      tf.assert_equal(tf.rank(starts), 2),
      tf.assert_equal(tf.rank(ends), 2)
  ]):
    predicted_areas = tf.argmax(logits, -1)
  return area_utils.area_to_refs(starts, ends, predicted_areas)
def test_random_normal(self, mock_stateless_random_normal):
  _ = dynamics.random_normal(shape=[3, 1], i=41 / 5, key=9)
  _, call_args = mock_stateless_random_normal.call_args
  assert_ops = [
      tf.assert_equal(tf.stack([9, 8]), call_args['seed']),
      tf.assert_equal([3, 1], call_args['shape'])
  ]
  with self.session() as sess:
    sess.run(assert_ops)
def compute_loss(self, unreduced_loss):
  """Computes scaled loss based on mask out size."""
  # Construct mask to identify zero padding that was introduced to
  # make the batch rectangular.
  batch_duration = tf.shape(self.pianorolls)[1]
  indices = tf.to_float(tf.range(batch_duration))
  pad_mask = tf.to_float(
      indices[None, :, None, None] < self.lengths[:, None, None, None])

  # Construct mask and its complement, respecting pad mask.
  mask = pad_mask * self.masks
  unmask = pad_mask * (1. - self.masks)

  # Compute numbers of variables:
  # #timesteps * #variables per timestep.
  variable_axis = 3 if self.hparams.use_softmax_loss else 2
  dd = (self.lengths[:, None, None, None] *
        tf.to_float(tf.shape(self.pianorolls)[variable_axis]))
  reduced_dd = tf.reduce_sum(dd)

  # Compute numbers of variables to be predicted/conditioned on.
  mask_size = tf.reduce_sum(mask, axis=[1, variable_axis], keep_dims=True)
  unmask_size = tf.reduce_sum(unmask, axis=[1, variable_axis], keep_dims=True)

  unreduced_loss *= pad_mask
  if self.hparams.rescale_loss:
    unreduced_loss *= dd / mask_size

  # Compute average loss over entire set of variables.
  self.loss_total = tf.reduce_sum(unreduced_loss) / reduced_dd

  # Compute separate losses for masked/unmasked variables.
  # NOTE: indexing the pitch dimension with 0 because the mask is constant
  # across pitch. Except in the sigmoid case, but then the pitch dimension
  # will have been reduced over.
  self.reduced_mask_size = tf.reduce_sum(mask_size[:, :, 0, :])
  self.reduced_unmask_size = tf.reduce_sum(unmask_size[:, :, 0, :])

  assert_partition_op = tf.group(
      tf.assert_equal(tf.reduce_sum(mask * unmask), 0.),
      tf.assert_equal(self.reduced_mask_size + self.reduced_unmask_size,
                      reduced_dd))
  with tf.control_dependencies([assert_partition_op]):
    self.loss_mask = (
        tf.reduce_sum(mask * unreduced_loss) / self.reduced_mask_size)
    self.loss_unmask = (
        tf.reduce_sum(unmask * unreduced_loss) / self.reduced_unmask_size)

  # Check which loss to use as objective function.
  self.loss = (self.loss_mask if self.hparams.optimize_mask_only
               else self.loss_total)
def expand_first_dimension(inputs, dims):
  """Expands a `K-d` tensor along its first dimension into a `(K+n-1)-d` tensor.

  Converts `inputs` with shape [D0, D1, ..., D(K-1)] into a tensor of shape
  [dims[0], dims[1], ..., dims[-1], D1, ..., D(K-1)].

  Example:
    `inputs` is a tensor with shape [50, 20, 20, 3].
    new_tensor = expand_first_dimension(inputs, [10, 5]).
    new_tensor.shape -> [10, 5, 20, 20, 3].

  Args:
    inputs: a tensor with shape [D0, D1, ..., D(K-1)].
    dims: List with new dimensions to expand the first axis into. The length
      of `dims` is typically 2 or larger.

  Returns:
    a tensor with shape [dims[0], dims[1], ..., dims[-1], D1, ..., D(K-1)].
  """
  inputs_shape = combined_static_and_dynamic_shape(inputs)
  expanded_shape = tf.stack(dims + inputs_shape[1:])

  # Verify that it is possible to expand the first axis of inputs.
  assert_op = tf.assert_equal(
      inputs_shape[0], tf.reduce_prod(tf.stack(dims)),
      message=('First dimension of `inputs` cannot be expanded into provided '
               '`dims`'))

  with tf.control_dependencies([assert_op]):
    inputs_reshaped = tf.reshape(inputs, expanded_shape)

  return inputs_reshaped
def assert_shape_equal_along_first_dimension(shape_a, shape_b):
  """Asserts that shape_a and shape_b are the same along the 0th dimension.

  If the shapes are static, raises a ValueError when the shapes mismatch.
  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static, or a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if isinstance(shape_a[0], int) and isinstance(shape_b[0], int):
    if shape_a[0] != shape_b[0]:
      raise ValueError('Unequal first dimension {}, {}'.format(
          shape_a[0], shape_b[0]))
    else:
      return tf.no_op()
  else:
    return tf.assert_equal(shape_a[0], shape_b[0])
def assert_mvn_target_conservation(event_size, batch_size, **kwargs):
  initialization = tfd.MultivariateNormalFullCovariance(
      loc=tf.zeros(event_size),
      covariance_matrix=tf.eye(event_size)).sample(batch_size, seed=4)
  samples, leapfrogs = run_nuts_chain(
      event_size, batch_size, num_steps=1,
      initial_state=initialization, **kwargs)
  answer = samples[0][-1]
  check_cdf_agrees = (
      st.assert_multivariate_true_cdf_equal_on_projections_two_sample(
          answer, initialization, num_projections=100, false_fail_rate=1e-6))
  check_sample_shape = tf1.assert_equal(
      tf.shape(input=answer)[0], batch_size)
  unique, _ = tf.unique(leapfrogs[0])
  check_leapfrogs_vary = tf1.assert_greater_equal(
      tf.shape(input=unique)[0], 3)
  avg_leapfrogs = tf.math.reduce_mean(input_tensor=leapfrogs[0])
  check_leapfrogs = tf1.assert_greater_equal(
      avg_leapfrogs, tf.constant(4, dtype=avg_leapfrogs.dtype))
  movement = tf.linalg.norm(tensor=answer - initialization, axis=-1)
  # This movement distance (0.3) was copied from the univariate case.
  check_movement = tf1.assert_greater_equal(
      tf.reduce_mean(input_tensor=movement), 0.3)
  check_enough_power = tf1.assert_less(
      st.min_discrepancy_of_true_cdfs_detectable_by_dkwm_two_sample(
          batch_size, batch_size, false_fail_rate=1e-8,
          false_pass_rate=1e-6), 0.055)
  return (check_cdf_agrees, check_sample_shape, check_leapfrogs_vary,
          check_leapfrogs, check_movement, check_enough_power)
def _training(self):
  """Performs multiple training iterations of both policy and value baseline.

  Trains on the episodes collected in the memory and resets the memory
  afterwards. Always returns a summary string.

  Returns:
    Summary tensor.
  """
  with tf.name_scope('training'):
    assert_full = tf.assert_equal(
        self._memory_index, self._config.update_every)
    with tf.control_dependencies([assert_full]):
      data = self._memory.data()
    (observ, action, old_mean, old_logstd, reward), length = data
    with tf.control_dependencies([tf.assert_greater(length, 0)]):
      length = tf.identity(length)
    observ = self._observ_filter.transform(observ)
    reward = self._reward_filter.transform(reward)
    update_summary = self._perform_update_steps(
        observ, action, old_mean, old_logstd, reward, length)
    with tf.control_dependencies([update_summary]):
      penalty_summary = self._adjust_penalty(
          observ, old_mean, old_logstd, length)
    with tf.control_dependencies([penalty_summary]):
      clear_memory = tf.group(
          self._memory.clear(), self._memory_index.assign(0))
    with tf.control_dependencies([clear_memory]):
      weight_summary = utility.variable_summaries(
          tf.trainable_variables(), self._config.weight_summaries)
      return tf.summary.merge(
          [update_summary, penalty_summary, weight_summary])
def _check_batch_sizes(self, factor):
  """Checks that the batch size(s) for a factor matches the reference value."""
  # Should make these messages use quote characters instead of parentheses
  # when the bug with quote character rendering in assertion messages is
  # fixed. See b/129476712.
  if self._batch_size is None:
    batch_size = self.factors[0].batch_size()
    string = ("Batch size {} for factor (" + factor.name + ") of type "
              + utils.cls_name(factor) + " did not match value {} used by "
              "factor (" + self.factors[0].name + ") of type "
              + utils.cls_name(self.factors[0]) + ".")
  else:
    batch_size = self._batch_size
    string = ("Batch size {} for factor (" + factor.name + ") of type "
              + utils.cls_name(factor) + " did not match value {} which was "
              "passed to optimizer/estimator.")

  factor_batch_size = factor.batch_size()
  if isinstance(batch_size, int) and isinstance(factor_batch_size, int):
    if factor_batch_size != batch_size:
      raise ValueError(string.format(factor_batch_size, batch_size))
    return factor.check_partial_batch_sizes()
  else:
    message = string.format("(x)", "(y)")
    with tf.control_dependencies([factor.check_partial_batch_sizes()]):
      return tf.assert_equal(factor_batch_size, batch_size, message=message)
def assert_finite(x, data=None, summarize=None, message=None, name=None):
  """Assert all elements of `x` are finite.

  Args:
    x: Numeric `Tensor`.
    data: The tensors to print out if the condition is False. Defaults to
      error message and first few entries of `x`.
    summarize: Print this many entries of each tensor.
    message: A string to prefix to the default message.
    name: A name for this operation (optional). Defaults to "assert_finite".

  Returns:
    Op raising `InvalidArgumentError` unless all elements of `x` are finite.
    If static checks determine `x` is always finite, `x` is returned unchanged.

  Raises:
    ValueError: If static checks determine `x` contains non-finite values.
  """
  with tf.name_scope(name or 'assert_finite'):
    x_ = tf.get_static_value(x)
    if x_ is not None:
      if not np.all(np.isfinite(x_)):
        raise ValueError(message)
      return x
    assertion = tf1.assert_equal(
        tf.math.is_finite(x), tf.ones_like(x, tf.bool),
        data=data, summarize=summarize, message=message)
    with tf.control_dependencies([assertion]):
      return tf.identity(x)
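# Added usage sketch (an assumption, not from the source): when the value of
# `x` is statically known the check happens immediately at trace time;
# otherwise the assert op is attached as a control dependency on the output.
x = assert_finite(tf.constant([1.0, 2.0]))  # static path, returns x unchanged
# assert_finite(tf.constant([1.0, float('nan')]))  # would raise ValueError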
def assert_shape_equal(shape_a, shape_b):
  """Asserts that shape_a and shape_b are equal.

  If the shapes are static, raises a ValueError when the shapes mismatch.
  If the shapes are dynamic, raises a tf InvalidArgumentError when the shapes
  mismatch.

  Args:
    shape_a: a list containing shape of the first tensor.
    shape_b: a list containing shape of the second tensor.

  Returns:
    Either a tf.no_op() when shapes are all static, or a tf.assert_equal() op
    when the shapes are dynamic.

  Raises:
    ValueError: When shapes are both static and unequal.
  """
  if (all(isinstance(dim, int) for dim in shape_a) and
      all(isinstance(dim, int) for dim in shape_b)):
    if shape_a != shape_b:
      raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
    else:
      return tf.no_op()
  else:
    return tf.assert_equal(shape_a, shape_b)
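# Added usage sketch (an assumption, not from the source): shape lists here
# hold Python ints where dimensions are static and scalar tensors where they
# are dynamic, as produced by helpers like combined_static_and_dynamic_shape.
static_check = assert_shape_equal([2, 3], [2, 3])  # all static: tf.no_op()
# assert_shape_equal([2, 3], [2, 4])               # would raise ValueError now
dynamic_dim = tf.shape(tf.zeros([2, 3]))[0]        # scalar tensor, not an int
dynamic_check = assert_shape_equal([dynamic_dim, 3], [2, 3])  # runtime assert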
def create_id3_embedding(videos):
  """Embeds the given videos using the Inflated 3D Convolution network.

  Downloads the graph of the I3D from tf.hub and adds it to the graph on the
  first call.

  Args:
    videos: <float32>[batch_size, num_frames, height=224, width=224, depth=3].
      Expected range is [-1, 1].

  Returns:
    embedding: <float32>[batch_size, embedding_size]. embedding_size depends
      on the model used.

  Raises:
    ValueError: when a provided embedding_layer is not supported.
  """
  batch_size = 16
  module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1"

  # Making sure that we import the graph separately for
  # each different input video tensor.
  module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str(
      videos.name).replace(":", "_")

  assert_ops = [
      tf.Assert(
          tf.reduce_max(videos) <= 1.001,
          ["max value in frame is > 1", videos]),
      tf.Assert(
          tf.reduce_min(videos) >= -1.001,
          ["min value in frame is < -1", videos]),
      tf.assert_equal(
          tf.shape(videos)[0], batch_size,
          ["invalid frame batch size: ", tf.shape(videos)],
          summarize=6),
  ]
  with tf.control_dependencies(assert_ops):
    videos = tf.identity(videos)

  module_scope = "%s_apply_default/" % module_name

  # To check whether the module has already been loaded into the graph, we look
  # for a given tensor name. If this tensor name exists, we assume the function
  # has been called before and the graph was imported. Otherwise we import it.
  # Note: in theory, the tensor could exist but have the wrong shape.
  # This will happen if create_id3_embedding is called with a frames_placeholder
  # of the wrong size/batch size, because even though that will throw a
  # tf.Assert at graph-execution time, it will insert the tensor (with the
  # wrong shape) into the graph. This is why we need the following assert.
  video_batch_size = int(videos.shape[0])
  assert video_batch_size in [batch_size, -1, None], "Invalid batch size"
  tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
  if not _is_in_graph(tensor_name):
    i3d_model = hub.Module(module_spec, name=module_name)
    i3d_model(videos)

  # Gets the kinetics-i3d-400-logits layer.
  tensor_name = module_scope + "RGB/inception_i3d/Mean:0"
  tensor = tf.get_default_graph().get_tensor_by_name(tensor_name)
  return tensor
def _apply_repeated_text_masking(
    config: RetrieverConfig,
    question_hash: tf.Tensor,
    question_hash_transposed: tf.Tensor,
    labels: tf.Tensor,
    logits: tf.Tensor,
) -> tf.Tensor:
  """Applies repeated text masking.

  Args:
    config: Retriever config.
    question_hash: <int64>[global_batch_size, 1]
    question_hash_transposed: <int64>[1, batch_size]
    labels: <int64>[batch_size, global_batch_size * num_tables]
    logits: <float>[batch_size, global_batch_size * num_tables]

  Returns:
    Masked logits (same shape / dtype).
  """
  # Make sure not all hashes are 0.
  # This indicates the "question_hash" feature wasn't set.
  assert_op = tf.assert_equal(
      tf.math.reduce_all(tf.math.equal(question_hash, 0)), [False])
  with tf.control_dependencies([assert_op]):
    logging.vlog(2, "question_hash: %s", question_hash)
    logging.vlog(2, "question_hash_transposed: %s", question_hash_transposed)
    logging.vlog(2, "labels: %s", labels)
    logging.vlog(2, "logits: %s", logits)
    # <bool>[batch_size, global_batch_size]
    repeated_texts = tf.math.equal(question_hash, question_hash_transposed)
    if config.use_mined_negatives:
      batch_size = repeated_texts.shape[0]
      global_batch_size = repeated_texts.shape[1]
      num_tables = logits.shape[1] // global_batch_size
      # <bool>[batch_size, global_batch_size * num_tables]
      repeated_texts = tf.concat([
          repeated_texts,
          tf.zeros(
              shape=(batch_size, (num_tables - 1) * global_batch_size),
              dtype=tf.bool)
      ], axis=1)
    repeated_texts = (
        repeated_texts
        # Makes sure the original correct question pair isn't masked.
        & tf.math.equal(labels, 0))
  logging.vlog(2, "repeated texts: %s", repeated_texts)
  ops = []
  if logging.vlog_is_on(2):
    ops.append(
        tf.print(
            "repeated texts content:",
            question_hash,
            repeated_texts,
            output_stream=logging.info,
        ))
  with tf.control_dependencies(ops):
    return tf.where(repeated_texts, tf.zeros_like(logits) - _INF, logits)
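# Added for illustration (not from the source): a duplicated question in the
# batch would act as a false negative, so its logit is pushed to -inf while
# the true pair (labels != 0) is left alone. NumPy sketch with toy values;
# `_INF` stands in for the module constant:
import numpy as np

_INF = 1e9
question_hash = np.array([[7], [9], [7]])    # batch of 3; rows 0 and 2 repeat
repeated = question_hash == question_hash.T  # [3, 3] hash-equality matrix
labels = np.eye(3, dtype=np.int64)           # the correct question/table pairs
repeated &= labels == 0                      # never mask the true pair
logits = np.zeros((3, 3))
masked = np.where(repeated, logits - _INF, logits)
assert masked[0, 2] == -_INF and masked[0, 0] == 0.0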
def assert_equal(*args, **kwargs):
  """Wrapper for tf.assert_equal.

  Overrides tf.device so that the assert always goes on CPU. The unwrapped
  version raises an exception if used with tf.device("/GPU:x").
  """
  with tf.device("/CPU:0"):
    return tf.assert_equal(*args, **kwargs)
def verify_example_ids(self):
  tensor = tf.strings.to_hash_bucket_fast(self._example_ids, 2**31 - 1)
  if self._role == 'leader':
    self.send('_verify_example_ids', tensor)
  else:
    recv_tensor = self.recv('_verify_example_ids', tensor.dtype)
    op = tf.assert_equal(tensor, recv_tensor)
    self._train_ops.append(op)
def symbols_to_logits(_, i, states):
  # We have to assert the values of state inline here since we can't fetch
  # them out of the loop!
  with tf.control_dependencies(
      [tf.assert_equal(states["state"], expected_states[i])]):
    logits = tf.to_float(tf.log(probabilities[i, :]))
  states["state"] += tf.constant([[3.], [7.]])
  return logits, states
def tflite_compat_mel(wav_audio, hparams):
  """EXPERIMENTAL: Log mel spec with ops that can be made TFLite compatible."""
  samples, decoded_sample_rate = tf.audio.decode_wav(
      wav_audio, desired_channels=1)
  samples = tf.squeeze(samples, axis=1)
  # Ensure that we decoded the samples at the expected sample rate.
  with tf.control_dependencies(
      [tf.assert_equal(decoded_sample_rate, hparams.sample_rate)]):
    return tflite_compat_mel_from_samples(samples, hparams)
def symbols_to_logits(ids, _, states):
  pos = tf.shape(ids)[1] - 1
  # We have to assert the values of state inline here since we can't fetch
  # them out of the loop!
  with tf.control_dependencies(
      [tf.assert_equal(states["state"], expected_states[pos])]):
    logits = tf.to_float(tf.log(probabilities[pos, :]))
  states["state"] += 1
  return logits, states
def validate_equal_last_dim(tensor_a, tensor_b, message):
  # Note: `validate_args` and `assertions` are defined in the enclosing scope
  # (this is a nested helper).
  if tensor_a.shape.is_fully_defined() and tensor_b.shape.is_fully_defined():
    if tensor_a.shape[-1] != tensor_b.shape[-1]:
      raise ValueError(message)
  elif validate_args:
    assertions.append(
        tf1.assert_equal(
            tf.shape(tensor_a)[-1], tf.shape(tensor_b)[-1], message=message))
def __readImages(self, filename):
  image_string = tf.read_file(filename)  # Gets a string tensor from a file.
  decodedInput = tf.image.decode_image(image_string)  # Decode the string tensor as an image.
  floatInput = tf.image.convert_image_dtype(
      decodedInput, dtype=tf.float32)  # Convert the image to float32.
  assertion = tf.assert_equal(
      tf.shape(floatInput)[-1], 3, message="image does not have 3 channels")
  with tf.control_dependencies([assertion]):
    floatInput.set_shape([None, None, 3])
  inputShape = floatInput.get_shape()

  if self.mode == "eval":
    # If the inputs are only the number of pictures declared.
    blackTargets = tf.zeros([
        self.inputImageSize, self.inputImageSize * self.nbTargetsToRead, 3])
    floatInput = tf.concat([floatInput, blackTargets], axis=1)

  # Split; we get a list of nbTargets + inputNumbers images.
  floatInputSplit = tf.split(
      floatInput,
      self.nbTargetsToRead + self.inputNumbers,
      axis=1,
      name="Split_input_data")

  # Sets the inputs and outputs depending on the order of images.
  if self.which_direction == "AtoB":
    inputs = floatInputSplit[:self.inputNumbers]
    targets = floatInputSplit[self.inputNumbers:]
  elif self.which_direction == "BtoA":
    inputs = floatInputSplit[self.inputNumbers:]
    targets = floatInputSplit[:self.inputNumbers]
  else:
    raise ValueError("Invalid direction")

  gammadInputs = inputs
  inputs = [tf.pow(input, 2.2) for input in inputs]  # Correct for the gamma.
  # If we want to log the inputs, we do it here.
  if self.logInput:
    inputs = [helpers.logTensor(input) for input in inputs]
  # The preprocess function rescales values from [0, 1] to [-1, 1].
  inputs = [helpers.preprocess(input) for input in inputs]
  # gammadInputs = [helpers.preprocess(gammadInput) for gammadInput in gammadInputs]
  targets = [helpers.preprocess(target) for target in targets]
  # We used to resize inputs and targets here; there is no functional need for
  # it. Will see if there is a technical need to define the actual size.
  return filename, inputs, targets, gammadInputs
def _maybe_mask(m, seq_len_mask):
  """Masks `m` with the sequence mask."""
  rank = m.get_shape().ndims
  rank = rank if rank is not None else tf.rank(m)
  extra_ones = tf.ones(rank - 2, dtype=tf.int32)
  m_batch_size = dimension_value(m.shape[0]) or tf.shape(m)[0]
  # Note: `seq_len_batch_size` is defined in the enclosing scope (this is a
  # nested helper); it is the batch size derived from the sequence lengths.
  with tf.control_dependencies(
      [tf.assert_equal(seq_len_batch_size, m_batch_size, message="batch")]):
    seq_len_mask = tf.reshape(
        seq_len_mask, tf.concat((tf.shape(seq_len_mask), extra_ones), 0))
    return m * seq_len_mask
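# Added for illustration (not from the source): the reshape appends singleton
# dimensions so the [batch, time] mask broadcasts over the memory depth.
import numpy as np

m = np.ones((2, 3, 4))  # [batch, max_time, depth]
seq_len_mask = np.array([[1., 1., 0.],
                         [1., 0., 0.]])
masked = m * seq_len_mask.reshape(2, 3, 1)  # one extra axis, as in extra_ones
assert (masked[0, 2] == 0.).all() and (masked[0, 1] == 1.).all()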
def _scan_fn(*_):
  exchange = exchange_proposed_fn(num_replica, seed)
  flat_replicas = tf.reshape(exchange, [-1])
  with tf.control_dependencies([
      tf1.assert_equal(
          tf.size(input=flat_replicas),
          tf.size(input=tf.unique(flat_replicas)[0])),
      tf1.assert_greater_equal(flat_replicas, 0),
      tf1.assert_less(flat_replicas, num_replica),
  ]):
    return tf.shape(input=exchange)[0]
def span_embedding(encoder_input_length, area_encodings, spans, hparams):
  """Computes the embedding for each span. (TODO: liyang): comment shapes."""
  with tf.control_dependencies([tf.assert_equal(tf.rank(area_encodings), 3)]):
    area_indices = area_utils.area_range_to_index(
        area_range=tf.reshape(spans, [-1, 2]),
        length=encoder_input_length,
        max_area_width=hparams.max_span)
  return area_utils.batch_gather(
      area_encodings,
      tf.reshape(area_indices, [tf.shape(spans)[0], tf.shape(spans)[1]]))
def query_area(query, area_encodings, area_bias):
  """Predicts a range of tokens based on the query.

  Args:
    query: a Tensor of shape [batch_size, length, depth]
    area_encodings: a tensor in shape of [batch_size, num_areas, depth]
    area_bias: a tensor in shape of [batch_size, num_areas].

  Returns:
    the logits for each area.
  """
  with tf.control_dependencies([
      tf.assert_equal(tf.rank(query), 3),
      tf.assert_equal(tf.rank(area_encodings), 3),
      tf.assert_equal(tf.shape(query)[-1], tf.shape(area_encodings)[-1]),
      tf.assert_equal(tf.rank(area_bias), 2)
  ]):
    dot_products = tf.matmul(query, tf.transpose(area_encodings, [0, 2, 1]))
    area_logits = dot_products + tf.expand_dims(area_bias, 1)
  return area_logits
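# Added for illustration (not from the source): the matmul scores every query
# position against every area; the bias broadcasts across query positions.
import numpy as np

batch, length, num_areas, depth = 2, 5, 7, 4
query = np.random.rand(batch, length, depth)
area_encodings = np.random.rand(batch, num_areas, depth)
area_bias = np.random.rand(batch, num_areas)
area_logits = query @ area_encodings.transpose(0, 2, 1) + area_bias[:, None, :]
assert area_logits.shape == (batch, length, num_areas)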