def get_initialized_params(self, trainable=False, scope="embedding",
                           reuse=False):
    """Creates the embedding variable, initialized straight from the table.

    In contrast to `get_params`, no Scaffold has to be run to initialize the
    variable. The price is that the initial value is produced by a
    `tf.py_func`, so this method is not compatible with `tf.SavedModel`.

    Args:
      trainable: Boolean; when true the params are created with
        `tf.get_variable`, otherwise as a local variable.
      scope: Name of the inner-most variable scope holding the params.
      reuse: Boolean indicating whether params in the same scope are reused.

    Returns:
      embedding_weights: The embedding weights variable.
    """
    # Serving `self._idx2emb` through tf.py_func keeps the table out of the
    # serialized graph, which would otherwise blow up the log sizes.
    init_value = tf.py_func(lambda: self._idx2emb, [], tf.float32, False)
    init_value.set_shape([len(self._idx2emb), self._dims])

    # Non-trainable embeddings live in a local variable so that they are not
    # dumped into the checkpoints.
    make_variable = tf.get_variable if trainable else tf.get_local_variable

    with tf.variable_scope(scope, reuse=reuse):
        return make_variable("embedding_weights", initializer=init_value)
def load_ragged_matrix(var_name, checkpoint_path):
    """Restores the three components of a ragged matrix from a checkpoint.

    The side file `checkpoint_path + ".info"` must contain two
    whitespace-separated integers: the number of rows and the number of
    stored values. They determine the shapes of the local variables that are
    created and then handed to `init_from_checkpoint`.

    Args:
      var_name: Prefix of the variable names; the suffixes `_data`,
        `_indices` and `_rowsplits` are appended to it.
      checkpoint_path: Path of the checkpoint (without the `.info` suffix).

    Returns:
      A `(data, indices, rowsplits)` tuple of local resource variables with
      shapes `[num_nnz]`, `[num_nnz]` and `[num_row + 1]`.
    """
    with tf.gfile.Open(checkpoint_path + ".info") as info_file:
        num_row, num_nnz = [int(field) for field in info_file.read().split()]

    def _local_vector(suffix, length, dtype):
        # All three components are 1-D local resource variables.
        return tf.get_local_variable(
            var_name + suffix, shape=[length], dtype=dtype, use_resource=True)

    tf_data = _local_vector("_data", num_nnz, tf.float32)
    tf_indices = _local_vector("_indices", num_nnz, tf.int64)
    tf_rowsplits = _local_vector("_rowsplits", num_row + 1, tf.int64)

    init_from_checkpoint(
        checkpoint_path,
        target_variables=[tf_data, tf_indices, tf_rowsplits])
    return tf_data, tf_indices, tf_rowsplits
def load_sparse_matrix(var_name, checkpoint_path):
    """Restores a 2-D sparse matrix from a checkpoint as a `tf.SparseTensor`.

    The side file `checkpoint_path + ".info"` must contain a single integer,
    the number of stored values, which fixes the shapes of the variables.

    Args:
      var_name: Prefix of the variable names; the suffixes `_data`,
        `_indices` and `_shape` are appended to it.
      checkpoint_path: Path of the checkpoint (without the `.info` suffix).

    Returns:
      A `tf.SparseTensor` built from the restored local variables.
    """
    with tf.gfile.Open(checkpoint_path + ".info") as info_file:
        num_nnz = int(info_file.read())

    # (name suffix, variable shape, dtype) for each sparse component.
    component_specs = (
        ("_data", [num_nnz], tf.float32),
        ("_indices", [num_nnz, 2], tf.int64),
        ("_shape", [2], tf.int64),
    )
    tf_data, tf_indices, tf_shape = [
        tf.get_local_variable(
            var_name + suffix, shape=shape, dtype=dtype, use_resource=True)
        for suffix, shape, dtype in component_specs
    ]

    init_from_checkpoint(
        checkpoint_path, target_variables=[tf_data, tf_indices, tf_shape])
    return tf.SparseTensor(tf_indices, tf_data, tf_shape)
def create_mips_searcher(var_name, checkpoint_path, num_neighbors):
    """Builds a brute-force top-k inner-product searcher over a database.

    Args:
      var_name: Name of the database variable inside the checkpoint.
      checkpoint_path: Path of the checkpoint holding the database.
      num_neighbors: Number of top-scoring entries returned per query.

    Returns:
      A `(database, search_fn)` pair. `search_fn(query)` multiplies `query`
      with the transposed database and returns `(top_scores, top_indices)`,
      each with static shape `[query.shape[0], num_neighbors]`.
    """
    database = load_database(var_name, None, checkpoint_path)

    # The barrier's initial value is created under a control dependency on
    # the database initializer, to make sure the DB is initialized.
    with tf.control_dependencies([database.initializer]):
        barrier_value = tf.constant(True)
    tf.get_local_variable("mips_init_barrier", initializer=barrier_value)

    def _search(query):
        with tf.device("/cpu:0"):
            scores = tf.matmul(query, database, transpose_b=True)
            top_scores, top_indices = tf.nn.top_k(scores, num_neighbors)
        # Both outputs share the same static shape.
        result_shape = [query.shape[0], num_neighbors]
        top_scores.set_shape(result_shape)
        top_indices.set_shape(result_shape)
        return top_scores, top_indices

    return database, _search
def load_database(var_name, shape, checkpoint_path, dtype=tf.float32):
    """Creates a local resource variable and restores it from a checkpoint.

    Args:
      var_name: Name of the variable, both in the graph and in the checkpoint.
      shape: Shape of the variable, or `None` to look the shape up in the
        checkpoint.
      checkpoint_path: Path of the checkpoint to restore from.
      dtype: Data type of the variable.

    Returns:
      The local variable handed to `init_from_checkpoint`.
    """
    if shape is None:
        # Ask the checkpoint itself for the stored shape of `var_name`.
        ckpt_reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
        shape = ckpt_reader.get_variable_to_shape_map()[var_name]

    database = tf.get_local_variable(
        var_name, shape=shape, dtype=dtype, use_resource=True)
    init_from_checkpoint(checkpoint_path, target_variables=[database])
    return database
def build_logits(data_ops, embed_layer, rnn_core, output_linear, name_prefix):
    """Core model logic: unrolls a Bayesian RNN over the given sequence.

    The RNN state is kept in local variables so that it carries over from one
    unroll to the next; a separate op is returned for resetting it.

    Args:
      data_ops: A `sequence_data.SequenceDataOps` namedtuple.
      embed_layer: A `snt.Embed` instance.
      rnn_core: A `snt.RNNCore` instance.
      output_linear: A `snt.Linear` instance.
      name_prefix: A string used to prefix local variable names.

    Returns:
      A tuple `(output_logits, assign_zero_rnn_state)`:
        output_logits: A 3D time-major tensor with the model's logits for a
          sequence of predictions. Shape `[time_steps, batch_size,
          vocab_size]`.
        assign_zero_rnn_state: A grouped op that overwrites every RNN state
          variable with zeros.
    """

    def _state_variable(initial_value):
        """Creates the local variable holding one component of the state."""
        variable_name = "{}/rnn_state/{}".format(
            name_prefix, initial_value.op.name)
        return tf.get_local_variable(variable_name, initializer=initial_value)

    def _zero_assign(state_variable):
        """Builds the op that zeroes one state variable."""
        return state_variable.assign(tf.zeros_like(state_variable))

    # Embed the input index sequence.
    apply_embedding = snt.BatchApply(embed_layer, name="input_embed_seq")
    embedded_input_seq = apply_embedding(data_ops.sparse_obs)

    # Variables holding the RNN state, plus the op that resets them.
    initial_rnn_state = nest.map_structure(
        _state_variable, rnn_core.initial_state(FLAGS.batch_size))
    zero_assign_ops = nest.map_structure(_zero_assign, initial_rnn_state)
    assign_zero_rnn_state = tf.group(*nest.flatten(zero_assign_ops))

    # Unroll the RNN core over the sequence.
    rnn_output_seq, rnn_final_state = tf.nn.dynamic_rnn(
        cell=rnn_core,
        inputs=embedded_input_seq,
        initial_state=initial_rnn_state,
        time_major=True)

    # Persist the RNN state for the next unroll: reading the outputs through
    # the identity below forces the state assignments to run.
    update_rnn_state = nest.map_structure(
        tf.assign, initial_rnn_state, rnn_final_state)
    with tf.control_dependencies(nest.flatten(update_rnn_state)):
        rnn_output_seq = tf.identity(rnn_output_seq, name="rnn_output_seq")

    apply_output = snt.BatchApply(output_linear, name="output_embed_seq")
    output_logits = apply_output(rnn_output_seq)
    return output_logits, assign_zero_rnn_state
def load_scann_searcher(var_name, checkpoint_path, num_neighbors,
                        dimensions_per_block=2, num_leaves=1000,
                        num_leaves_to_search=100,
                        training_sample_size=100000):
    """Loads a database from a checkpoint and builds a ScaNN searcher on it.

    Args:
      var_name: Name of the database tensor in the checkpoint; also used as
        the name of the local variable that holds it.
      checkpoint_path: Path of the checkpoint to read.
      num_neighbors: Passed to `ScannBuilder` as `num_neighbors`.
      dimensions_per_block: Passed to `ScannBuilder.score_ah`.
      num_leaves: Passed to `ScannBuilder.tree`.
      num_leaves_to_search: Passed to `ScannBuilder.tree`.
      training_sample_size: Passed to `ScannBuilder.tree`.

    Returns:
      A `(database_variable, searcher)` pair.
    """
    with tf.device("/cpu:0"):
        checkpoint_reader = tf.train.load_checkpoint(checkpoint_path)
        np_db = checkpoint_reader.get_tensor(var_name)

        # The array is fed through tf.py_func and used as the initial value
        # of a local variable; the static shape is restored by hand because
        # py_func outputs carry none.
        init_db = tf.py_func(lambda: np_db, [], tf.float32)
        init_db.set_shape(np_db.shape)
        tf_db = tf.get_local_variable(var_name, initializer=init_db)

        searcher = (
            ScannBuilder(
                db=tf_db,
                num_neighbors=num_neighbors,
                distance_measure="dot_product")
            .tree(
                num_leaves=num_leaves,
                num_leaves_to_search=num_leaves_to_search,
                training_sample_size=training_sample_size)
            .score_ah(dimensions_per_block=dimensions_per_block)
            .create_tf()
        )
    return tf_db, searcher
def __init__(self,
             train_batch_size=4096,
             test_chain_batch_size=4096,
             bijector="iaf",
             log_dir="/tmp/neutra",
             base_learning_rate=1e-3,
             q_base_scale=1.,
             learning_rate_schedule=[[6000, 1e-1]]):
    """Builds the full training and evaluation graph.

    Args:
      train_batch_size: Number of samples drawn from the training
        distribution per step; converted with `int()`.
      test_chain_batch_size: Default value of the `test_chain_batch_size`
        placeholder, passed to `MakeNeuTra` as `batch_size`.
      bijector: One of "rnvp", "iaf" or "affine". Any other value selects
        `tfb.Identity()`.
      log_dir: Stored unchanged on `self.log_dir`.
      base_learning_rate: Learning rate before the schedule multiplier is
        applied.
      q_base_scale: Scale of the normal base distribution.
      learning_rate_schedule: Sequence of `[step, factor]` pairs. The Adam
        learning rate is `base_learning_rate` times a piecewise-constant
        multiplier that is 1.0 before the first step boundary and the listed
        factor after each boundary.
    """
    # NOTE(review): `learning_rate_schedule` has a mutable list default. It is
    # only read here (never mutated), so the shared default is harmless.
    target, target_spec = GetTargetSpec()
    self.target = target
    self.target_spec = target_spec
    # The training target is obtained from the same factory under the "train"
    # gin config scope.
    with gin.config_scope("train"):
        train_target, train_target_spec = GetTargetSpec()
        self.train_target = train_target
        self.train_target_spec = train_target_spec

    # `tf.make_template` wraps the factory so that the train=True and
    # train=False bijectors built below come from the same template.
    if bijector == "rnvp":
        bijector_fn = tf.make_template(
            "bijector", MakeRNVPBijectorFn, num_dims=self.target_spec.num_dims)
    elif bijector == "iaf":
        bijector_fn = tf.make_template(
            "bijector", MakeIAFBijectorFn, num_dims=self.target_spec.num_dims)
    elif bijector == "affine":
        bijector_fn = tf.make_template(
            "bijector", MakeAffineBijectorFn,
            num_dims=self.target_spec.num_dims)
    else:
        # Unknown names fall back to the identity bijector.
        bijector_fn = lambda *args, **kwargs: tfb.Identity()

    self.train_bijector = bijector_fn(train=True)
    self.bijector = bijector_fn(train=False)
    # When a target spec carries its own bijector, it is chained in front of
    # the learned one.
    if train_target_spec.bijector is not None:
        print("Using train target bijector")
        self.train_bijector = tfb.Chain(
            [train_target_spec.bijector, self.train_bijector])
    if target_spec.bijector is not None:
        print("Using target bijector")
        self.bijector = tfb.Chain([target_spec.bijector, self.bijector])

    # Base distribution: independent normal with zero mean and
    # `q_base_scale` standard deviation in every dimension.
    q_base = tfd.Independent(
        tfd.Normal(
            loc=tf.zeros(self.target_spec.num_dims),
            scale=q_base_scale * tf.ones(self.target_spec.num_dims)), 1)
    self.q_x_train = tfd.TransformedDistribution(q_base, self.train_bijector)
    self.q_x = tfd.TransformedDistribution(q_base, self.bijector)

    # Params
    # The test-time settings are placeholders with defaults, so they can be
    # overridden through a feed at run time.
    self.train_batch_size = int(train_batch_size)
    self.test_chain_batch_size = tf.placeholder_with_default(
        test_chain_batch_size, [], "test_chain_batch_size")
    self.test_batch_size = tf.placeholder_with_default(
        16384 * 8, [], "test_batch_size")
    self.test_num_steps = tf.placeholder_with_default(
        1000, [], "test_num_steps")
    self.test_num_leapfrog_steps = tf.placeholder_with_default(
        tf.to_int32(2), [], "test_num_leapfrog_steps")
    self.test_step_size = tf.placeholder_with_default(
        0.1, [], "test_step_size")

    # Test
    self.neutra_outputs = MakeNeuTra(
        target=self.target,
        q=self.q_x,
        batch_size=self.test_chain_batch_size,
        num_steps=self.test_num_steps,
        num_leapfrog_steps=self.test_num_leapfrog_steps,
        step_size=self.test_step_size,
    )
    # Map the chain through the inverse bijector: flatten to
    # [-1, num_dims] for the inverse, then restore the chain's shape.
    self.z_chain = tf.reshape(
        self.bijector.inverse(
            tf.reshape(self.neutra_outputs.x_chain,
                       [-1, self.target_spec.num_dims])),
        tf.shape(self.neutra_outputs.x_chain))
    self.target_samples = self.target.sample(self.test_batch_size)
    self.target_z = self.bijector.inverse(self.target_samples)
    self.q_samples = self.q_x.sample(self.test_batch_size)

    # Eigendecomposition of the covariance of the target samples, with local
    # variables to cache the result; the update op list fills the caches.
    self.target_cov = utils.Covariance(self.target_samples)
    self.target_eigvals, self.target_eigvecs = tf.linalg.eigh(
        self.target_cov)
    self.cached_target_eigvals = tf.get_local_variable(
        "cached_target_eigvals",
        self.target_eigvals.shape,
        initializer=tf.zeros_initializer())
    self.cached_target_eigvecs = tf.get_local_variable(
        "cached_target_eigvecs",
        self.target_eigvecs.shape,
        initializer=tf.zeros_initializer())
    self.cached_target_stats_update_op = [
        self.cached_target_eigvals.assign(self.target_eigvals),
        self.cached_target_eigvecs.assign(self.target_eigvecs),
        tf.print("Assigning target stats")
    ]

    # NOTE(review): `variance` and `rotated_variance` are only referenced by
    # the commented-out entries of `functions` below, so they are currently
    # unused.
    def variance(x):
        x -= tf.reduce_mean(x, 0, keep_dims=True)
        x = tf.square(x)
        return x

    def rotated_variance(x):
        x2 = tf.reshape(x, [-1, self.target_spec.num_dims])
        x2 -= tf.reduce_mean(x2, 0, keep_dims=True)
        x2 = tf.matmul(x2, self.cached_target_eigvecs)
        x2 = tf.square(x2)
        return tf.reshape(x2, tf.shape(x))

    # (name, transform) pairs; one set of statistics is built per entry.
    functions = [
        ("mean", tf.identity),
        # ("var", variance),
        ("square", tf.square),
        # ("rot_square", rot_square),
        # ("rot_var", rotated_variance),
    ]

    self.cached_target_mean = {}
    self.cached_target_mean_update_op = [
        tf.print("Assigning target means.")
    ]
    self.neutra_stats = {}
    self.q_stats = {}

    for name, f in functions:
        target_mean = tf.reduce_mean(f(self.target_samples), 0)
        cached_target_mean = tf.get_local_variable(name + "_cached_mean",
                                                   target_mean.shape)
        # Stats supplied by the target spec take precedence over the mean
        # estimated from target samples.
        if self.target_spec.stats is not None:
            self.cached_target_mean_update_op.append(
                cached_target_mean.assign(self.target_spec.stats[name]))
        else:
            self.cached_target_mean_update_op.append(
                cached_target_mean.assign(target_mean))

        self.cached_target_mean[name] = cached_target_mean
        self.q_stats[name] = ComputeQStats(f(self.q_samples),
                                           cached_target_mean)
        self.neutra_stats[name] = ComputeChainStats(
            f(self.neutra_outputs.x_chain), cached_target_mean,
            self.test_num_leapfrog_steps)

    # Training
    self.train_q_samples = self.q_x_train.sample(self.train_batch_size)
    self.train_log_q_x = self.q_x_train.log_prob(self.train_q_samples)
    # Sample estimate of E_q[log q(x) - log p(x)].
    # NOTE(review): this evaluates `self.target.log_prob`, not
    # `self.train_target`; `self.train_target` is stored above but not used in
    # this method - confirm that this is intended.
    self.kl_q_p = tf.reduce_mean(
        self.train_log_q_x - self.target.log_prob(self.train_q_samples))

    loss = self.kl_q_p
    reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    if reg_losses:
        tf.logging.info("Regularizing.")
        loss += tf.add_n(reg_losses)
    self.loss = tf.check_numerics(loss, "Loss has NaNs")

    self.global_step = tf.train.get_or_create_global_step()
    # Split the [step, factor] pairs into boundaries and multipliers; the
    # multiplier before the first boundary is 1.0.
    steps, factors = list(zip(*learning_rate_schedule))
    learning_rate = base_learning_rate * tf.train.piecewise_constant(
        self.global_step, steps, [1.0] + list(factors))

    opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
    self.train_op = opt.minimize(self.loss, global_step=self.global_step)

    tf.summary.scalar("kl_q_p", self.kl_q_p)
    tf.summary.scalar("loss", self.loss)

    self.init = [
        tf.global_variables_initializer(),
        tf.local_variables_initializer(),
        tf.print("Initializing variables")
    ]

    self.saver = tf.train.Saver()
    self.log_dir = log_dir
def retrieve(features, retriever_beam_size, mode, params):
    """Retrieves the top-scoring blocks for a single question.

    Args:
      features: Dict containing the question string under key "question".
      retriever_beam_size: Number of blocks to retrieve.
      mode: A `tf.estimator.ModeKeys` value; the hub module is loaded with
        the "train" tag only in TRAIN mode.
      params: Dict with the keys "retriever_module_path",
        "block_records_path" and "num_block_records".

    Returns:
      A `RetrieverOutputs` with the retrieval `logits` and the retrieved
      `blocks`.
    """
    module_path = params["retriever_module_path"]
    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(module_path)

    # Tokenize the question as a batch of one and flatten it to a dense
    # int32 id matrix.
    ragged_token_ids = tokenizer.tokenize(
        tf.expand_dims(features["question"], 0))
    dense_token_ids = tf.cast(
        ragged_token_ids.merge_dims(1, 2).to_tensor(), tf.int32)

    # Wrap the ids as [CLS] ... [SEP].
    cls_token_id = vocab_lookup_table.lookup(tf.constant("[CLS]"))
    sep_token_id = vocab_lookup_table.lookup(tf.constant("[SEP]"))
    cls_column = [[tf.cast(cls_token_id, tf.int32)]]
    sep_column = [[tf.cast(sep_token_id, tf.int32)]]
    question_token_ids = tf.concat(
        [cls_column, dense_token_ids, sep_column], -1)

    if mode == tf.estimator.ModeKeys.TRAIN:
        module_tags = {"train"}
    else:
        module_tags = {}
    retriever_module = hub.Module(
        module_path, tags=module_tags, trainable=True)

    # [1, projection_size]
    module_inputs = dict(
        input_ids=question_token_ids,
        input_mask=tf.ones_like(question_token_ids),
        segment_ids=tf.zeros_like(question_token_ids))
    question_emb = retriever_module(
        inputs=module_inputs, signature="projected")

    encoded_checkpoint = os.path.join(module_path, "encoded", "encoded.ckpt")
    block_emb, searcher = scann_utils.load_scann_searcher(
        var_name="block_emb",
        checkpoint_path=encoded_checkpoint,
        num_neighbors=retriever_beam_size)

    # [1, retriever_beam_size]
    batched_block_ids, _ = searcher.search_batched(question_emb)

    # [1, retriever_beam_size, projection_size]
    batched_block_emb = tf.gather(block_emb, batched_block_ids)

    # [retriever_beam_size]
    retrieved_block_ids = tf.squeeze(batched_block_ids)

    # [retriever_beam_size, projection_size]
    retrieved_block_emb = tf.squeeze(batched_block_emb)

    # [1, retriever_beam_size]
    batched_logits = tf.matmul(
        question_emb, retrieved_block_emb, transpose_b=True)

    # [retriever_beam_size]
    retrieved_logits = tf.squeeze(batched_logits, 0)

    # Read all block records as one batch and keep them in a local variable.
    record_dataset = tf.data.TFRecordDataset(
        params["block_records_path"], buffer_size=512 * 1024 * 1024)
    batched_records = record_dataset.batch(
        params["num_block_records"], drop_remainder=True)
    blocks = tf.get_local_variable(
        "blocks",
        initializer=tf.data.experimental.get_single_element(batched_records))

    retrieved_blocks = tf.gather(blocks, retrieved_block_ids)
    return RetrieverOutputs(logits=retrieved_logits, blocks=retrieved_blocks)
def bucket_by_quantiles(len_fn, batch_size, n_buckets, hist_bounds):
    """Dynamically bucket a `tf.data.Dataset` based on the element's length.

    Keeps track of a histogram of the input elements' lengths, and yields
    batches of examples that belong to the same length quantile. Useful in
    cases where you want to bucket data, but don't know what the optimal
    bucketing ranges should be.

    Args:
      len_fn: Function mapping elements in the dataset to an integer length
        (an int32 or int64 tensor).
      batch_size: Maximum size of the output batches.
      n_buckets: Number of quantiles to break the data into.
      hist_bounds: List of integer bounds to use when building the
        histograms. Must be strictly ascending and contain at least
        `n_buckets` entries. It should cover a range so that at most a single
        quantile of elements is lower/higher than the bucket range. More
        fine-grained bounds make the histogram more precise, but add to the
        computational overhead.

    Raises:
      ValueError: If `hist_bounds` has fewer entries than `n_buckets` or is
        not strictly ascending, or (when the returned transformation is
        applied) if `len_fn` returns a non-integer tensor.

    Returns:
      A function that can be used with `tf.data.Dataset.apply` to batch a
      dataset.
    """
    n_hist_bins = len(hist_bounds)
    if n_hist_bins < n_buckets:
        raise ValueError(
            "Requested %d buckets, but only have %d histogram bins" %
            (n_buckets, n_hist_bins))
    # Every bound must be strictly smaller than its successor.
    if any(lower >= upper
           for lower, upper in zip(hist_bounds, hist_bounds[1:])):
        raise ValueError("Bins must be strictly ascending")

    # Need to use `use_resource = True` to make this work correctly
    # within tf.data.Dataset
    hist_counts = tf.get_local_variable(
        "hist-counts", n_hist_bins + 1, tf.int64, tf.zeros_initializer(),
        use_resource=True)
    bounds_tensor = tf.constant(hist_bounds, tf.int64)

    def bucket_fn(x):
        """Compute the element bucket and update the histogram."""
        ix = len_fn(x)
        if ix.dtype == tf.int32:
            ix = tf.to_int64(ix)
        elif ix.dtype != tf.int64:
            raise ValueError("Len function returned a non-int")

        # One entry per bound that is larger than the length, so the counts
        # are cumulative: count[i] = number of elements shorter than bound i.
        adds_to_bins = tf.to_int64(tf.greater(bounds_tensor, ix))
        # pad with a 1 for the "larger than all" bin
        adds_to_bins = tf.pad(adds_to_bins, [[0, 1]], constant_values=1)
        new_counts = tf.assign_add(hist_counts, adds_to_bins)
        # Index of the first bin this element was added to.
        bin_ix = n_hist_bins - tf.reduce_sum(adds_to_bins)
        # Computes the quantile based on the counts of the example's bucket;
        # new_counts[-1] is the total number of elements seen so far.
        bucket_ix = tf.floordiv(((n_buckets - 1) * new_counts[bin_ix]),
                                new_counts[-1])
        return bucket_ix

    def reduce_fn(_, x):
        return x.padded_batch(batch_size,
                              dataset_ops.get_legacy_output_shapes(x))

    return tf.data.experimental.group_by_window(bucket_fn, reduce_fn,
                                                batch_size)