Example #1
  def get_initialized_params(self, trainable=False, scope="embedding",
                             reuse=False):
    """Returns a variable with the embeddings.

    Unlike `get_params`, this does not require running a Scaffold to initialize
    the variable. However, this method is not compatible with `tf.SavedModel`,
    since it uses a `tf.py_func` to initialize the embeddings variable.

    Args:
      trainable: Boolean indicating whether the params should be trainable.
      scope: The name of the inner-most scope for the params.
      reuse: Boolean indicating whether to reuse params in the same scope.

    Returns:
      embedding_weights: The embedding weights.
    """

    # Hide `self._idx2emb` behind tf.py_func so it does not get serialized as
    # part of the graph and blow up our log sizes.
    init_value = tf.py_func(
        lambda: self._idx2emb, [], tf.float32, stateful=False)
    init_value.set_shape([len(self._idx2emb), self._dims])

    with tf.variable_scope(scope, reuse=reuse):
      if trainable:
        embedding_weights = tf.get_variable(
            "embedding_weights", initializer=init_value)
      else:
        # Local variable so the embeddings won't get dumped into the checkpoints
        embedding_weights = tf.get_local_variable(
            "embedding_weights", initializer=init_value)
    return embedding_weights
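
A minimal usage sketch for the method above, assuming a TF1-style `tf` import and a hypothetical `embedder` instance whose `_idx2emb` attribute holds the embedding matrix as a NumPy array. Because the non-trainable path creates a local variable, it is initialized by `tf.local_variables_initializer()` rather than the global initializer:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

# `embedder` is a hypothetical instance of the class defining the method above.
embedding_weights = embedder.get_initialized_params(trainable=False)
token_ids = tf.constant([0, 1, 2])
embedded = tf.nn.embedding_lookup(embedding_weights, token_ids)

with tf.Session() as sess:
  # Local variables are not covered by tf.global_variables_initializer().
  sess.run(tf.local_variables_initializer())
  vectors = sess.run(embedded)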
Example #2
def load_ragged_matrix(var_name, checkpoint_path):
  """Load sparse matrix from checkpoint."""
  with tf.gfile.Open(checkpoint_path + ".info") as f:
    num_row, num_nnz = [int(xx) for xx in f.read().split()]
  tf_data = tf.get_local_variable(
      var_name + "_data", shape=[num_nnz], dtype=tf.float32, use_resource=True)
  tf_indices = tf.get_local_variable(
      var_name + "_indices", shape=[num_nnz], dtype=tf.int64, use_resource=True)
  tf_rowsplits = tf.get_local_variable(
      var_name + "_rowsplits",
      shape=[num_row + 1],
      dtype=tf.int64,
      use_resource=True)
  init_from_checkpoint(
      checkpoint_path, target_variables=[tf_data, tf_indices, tf_rowsplits])
  return tf_data, tf_indices, tf_rowsplits
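
One plausible way for a caller to reassemble the three loaded pieces is as a `tf.RaggedTensor` over the per-row values. This is only a hedged sketch (the real call site is not shown here), and the variable name and checkpoint path are hypothetical:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

data, indices, row_splits = load_ragged_matrix("my_matrix", "/tmp/my_matrix.ckpt")
# `indices` holds the per-value column ids; it is unused in this sketch.
ragged_values = tf.RaggedTensor.from_row_splits(data, row_splits)

with tf.Session() as sess:
  # Running the local initializer triggers the checkpoint-backed initializers.
  sess.run(tf.local_variables_initializer())
  print(sess.run(ragged_values.row_lengths()))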
Example #3
def load_sparse_matrix(var_name, checkpoint_path):
  """Load sparse matrix from checkpoint."""
  with tf.gfile.Open(checkpoint_path + ".info") as f:
    num_nnz = int(f.read())
  tf_data = tf.get_local_variable(
      var_name + "_data", shape=[num_nnz], dtype=tf.float32, use_resource=True)
  tf_indices = tf.get_local_variable(
      var_name + "_indices",
      shape=[num_nnz, 2],
      dtype=tf.int64,
      use_resource=True)
  tf_shape = tf.get_local_variable(
      var_name + "_shape", shape=[2], dtype=tf.int64, use_resource=True)
  init_from_checkpoint(
      checkpoint_path, target_variables=[tf_data, tf_indices, tf_shape])
  tf_sp = tf.SparseTensor(tf_indices, tf_data, tf_shape)
  return tf_sp
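
A minimal usage sketch; the variable name and checkpoint path are hypothetical, and the checkpoint is assumed to hold the `<var_name>_data`, `_indices`, and `_shape` tensors along with the matching `.info` file:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

adjacency = load_sparse_matrix("adjacency", "/tmp/adjacency.ckpt")
row_sums = tf.sparse.reduce_sum(adjacency, axis=1)

with tf.Session() as sess:
  # The underlying variables are local, so use the local initializer.
  sess.run(tf.local_variables_initializer())
  print(sess.run(row_sums))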
Example #4
def create_mips_searcher(var_name, checkpoint_path, num_neighbors):
    """Create searcher for returning top-k closest elements."""
    tf_db = load_database(var_name, None, checkpoint_path)

    with tf.control_dependencies([tf_db.initializer]):
        mips_init_barrier = tf.constant(True)

    # Make sure DB is initialized.
    tf.get_local_variable("mips_init_barrier", initializer=mips_init_barrier)

    def _search(query):
        with tf.device("/cpu:0"):
            distance = tf.matmul(query, tf_db, transpose_b=True)
            topk_dist, topk_idx = tf.nn.top_k(distance, num_neighbors)
        topk_dist.set_shape([query.shape[0], num_neighbors])
        topk_idx.set_shape([query.shape[0], num_neighbors])
        return topk_dist, topk_idx

    return tf_db, _search
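
A hedged usage sketch; `create_mips_searcher` relies on `load_database` from Example #5 below, and the variable name, checkpoint path, and batch size here are hypothetical. Because `_search` calls `set_shape` on its results, the query batch dimension must be statically known:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

tf_db, search_fn = create_mips_searcher(
    "block_emb", "/tmp/block_emb.ckpt", num_neighbors=5)
queries = tf.placeholder(tf.float32, shape=[8, tf_db.shape.as_list()[1]])
topk_dist, topk_idx = search_fn(queries)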
Example #5
def load_database(var_name, shape, checkpoint_path, dtype=tf.float32):
  """Load variable from checkpoint."""
  if shape is None:
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    shape = var_to_shape_map[var_name]
  tf_db = tf.get_local_variable(
      var_name, shape=shape, dtype=dtype, use_resource=True)
  init_from_checkpoint(checkpoint_path, target_variables=[tf_db])
  return tf_db
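
Example #4 above passes `shape=None`, which makes this helper read the shape from the checkpoint itself; passing an explicit shape skips that lookup. A hedged sketch with hypothetical name, shape, and path:

block_emb = load_database(
    "block_emb", shape=[1000000, 128], checkpoint_path="/tmp/block_emb.ckpt")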
Example #6
def build_logits(data_ops, embed_layer, rnn_core, output_linear, name_prefix):
    """This is the core model logic.

  Unrolls a Bayesian RNN over the given sequence.

  Args:
    data_ops: A `sequence_data.SequenceDataOps` namedtuple.
    embed_layer: A `snt.Embed` instance.
    rnn_core: A `snt.RNNCore` instance.
    output_linear: A `snt.Linear` instance.
    name_prefix: A string to use to prefix local variable names.

  Returns:
    A 3D time-major tensor representing the model's logits for a sequence of
    predictions. Shape `[time_steps, batch_size, vocab_size]`.
  """
    # Embed the input index sequence.
    embedded_input_seq = snt.BatchApply(embed_layer, name="input_embed_seq")(
        data_ops.sparse_obs)

    # Construct variables for holding the RNN state.
    initial_rnn_state = nest.map_structure(
        lambda t: tf.get_local_variable(  # pylint: disable=g-long-lambda
            "{}/rnn_state/{}".format(name_prefix, t.op.name),
            initializer=t),
        rnn_core.initial_state(FLAGS.batch_size))
    assign_zero_rnn_state = nest.map_structure(
        lambda x: x.assign(tf.zeros_like(x)), initial_rnn_state)
    assign_zero_rnn_state = tf.group(*nest.flatten(assign_zero_rnn_state))

    # Unroll the RNN core over the sequence.
    rnn_output_seq, rnn_final_state = tf.nn.dynamic_rnn(
        cell=rnn_core,
        inputs=embedded_input_seq,
        initial_state=initial_rnn_state,
        time_major=True)

    # Persist the RNN state for the next unroll.
    update_rnn_state = nest.map_structure(tf.assign, initial_rnn_state,
                                          rnn_final_state)
    with tf.control_dependencies(nest.flatten(update_rnn_state)):
        rnn_output_seq = tf.identity(rnn_output_seq, name="rnn_output_seq")
    output_logits = snt.BatchApply(output_linear,
                                   name="output_embed_seq")(rnn_output_seq)
    return output_logits, assign_zero_rnn_state
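
A hedged construction sketch, assuming Sonnet 1.x (`snt`), a defined `FLAGS.batch_size`, and that `data_ops.sparse_obs` is a time-major `[time_steps, batch_size]` int tensor of token ids. The module sizes and the namedtuple stand-in below are hypothetical:

import collections
import sonnet as snt
import tensorflow.compat.v1 as tf

# Stand-in for sequence_data.SequenceDataOps, with only the field used above.
SequenceDataOps = collections.namedtuple("SequenceDataOps", ["sparse_obs"])
data_ops = SequenceDataOps(
    sparse_obs=tf.zeros([35, FLAGS.batch_size], dtype=tf.int32))

embed_layer = snt.Embed(vocab_size=10000, embed_dim=128)
rnn_core = snt.LSTM(hidden_size=256)
output_linear = snt.Linear(output_size=10000)

logits, reset_state_op = build_logits(
    data_ops, embed_layer, rnn_core, output_linear, name_prefix="train")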
Example #7
def load_scann_searcher(var_name,
                        checkpoint_path,
                        num_neighbors,
                        dimensions_per_block=2,
                        num_leaves=1000,
                        num_leaves_to_search=100,
                        training_sample_size=100000):
    """Load scann searcher from checkpoint."""
    with tf.device("/cpu:0"):
        np_db = tf.train.load_checkpoint(checkpoint_path).get_tensor(var_name)
        init_db = tf.py_func(lambda: np_db, [], tf.float32)
        init_db.set_shape(np_db.shape)
        tf_db = tf.get_local_variable(var_name, initializer=init_db)

        builder = ScannBuilder(db=tf_db,
                               num_neighbors=num_neighbors,
                               distance_measure="dot_product")
        builder = builder.tree(num_leaves=num_leaves,
                               num_leaves_to_search=num_leaves_to_search,
                               training_sample_size=training_sample_size)
        builder = builder.score_ah(dimensions_per_block=dimensions_per_block)
        searcher = builder.create_tf()
    return tf_db, searcher
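
A hedged usage sketch (Example #9 below shows the same searcher used in context); the variable name, checkpoint path, and embedding size are hypothetical:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

block_emb, searcher = load_scann_searcher(
    var_name="block_emb",
    checkpoint_path="/tmp/encoded.ckpt",
    num_neighbors=10)
query_emb = tf.placeholder(tf.float32, [None, block_emb.shape.as_list()[1]])
neighbor_ids, _ = searcher.search_batched(query_emb)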
Example #8
    def __init__(self,
                 train_batch_size=4096,
                 test_chain_batch_size=4096,
                 bijector="iaf",
                 log_dir="/tmp/neutra",
                 base_learning_rate=1e-3,
                 q_base_scale=1.,
                 learning_rate_schedule=[[6000, 1e-1]]):
        target, target_spec = GetTargetSpec()
        self.target = target
        self.target_spec = target_spec
        with gin.config_scope("train"):
            train_target, train_target_spec = GetTargetSpec()
            self.train_target = train_target
            self.train_target_spec = train_target_spec

        if bijector == "rnvp":
            bijector_fn = tf.make_template("bijector",
                                           MakeRNVPBijectorFn,
                                           num_dims=self.target_spec.num_dims)
        elif bijector == "iaf":
            bijector_fn = tf.make_template("bijector",
                                           MakeIAFBijectorFn,
                                           num_dims=self.target_spec.num_dims)
        elif bijector == "affine":
            bijector_fn = tf.make_template("bijector",
                                           MakeAffineBijectorFn,
                                           num_dims=self.target_spec.num_dims)
        else:
            bijector_fn = lambda *args, **kwargs: tfb.Identity()

        self.train_bijector = bijector_fn(train=True)
        self.bijector = bijector_fn(train=False)
        if train_target_spec.bijector is not None:
            print("Using train target bijector")
            self.train_bijector = tfb.Chain(
                [train_target_spec.bijector, self.train_bijector])
        if target_spec.bijector is not None:
            print("Using target bijector")
            self.bijector = tfb.Chain([target_spec.bijector, self.bijector])

        q_base = tfd.Independent(
            tfd.Normal(loc=tf.zeros(self.target_spec.num_dims),
                       scale=q_base_scale *
                       tf.ones(self.target_spec.num_dims)), 1)
        self.q_x_train = tfd.TransformedDistribution(q_base,
                                                     self.train_bijector)
        self.q_x = tfd.TransformedDistribution(q_base, self.bijector)

        # Params
        self.train_batch_size = int(train_batch_size)
        self.test_chain_batch_size = tf.placeholder_with_default(
            test_chain_batch_size, [], "test_chain_batch_size")
        self.test_batch_size = tf.placeholder_with_default(
            16384 * 8, [], "test_batch_size")
        self.test_num_steps = tf.placeholder_with_default(
            1000, [], "test_num_steps")
        self.test_num_leapfrog_steps = tf.placeholder_with_default(
            tf.to_int32(2), [], "test_num_leapfrog_steps")
        self.test_step_size = tf.placeholder_with_default(
            0.1, [], "test_step_size")

        # Test
        self.neutra_outputs = MakeNeuTra(
            target=self.target,
            q=self.q_x,
            batch_size=self.test_chain_batch_size,
            num_steps=self.test_num_steps,
            num_leapfrog_steps=self.test_num_leapfrog_steps,
            step_size=self.test_step_size,
        )
        self.z_chain = tf.reshape(
            self.bijector.inverse(
                tf.reshape(self.neutra_outputs.x_chain,
                           [-1, self.target_spec.num_dims])),
            tf.shape(self.neutra_outputs.x_chain))
        self.target_samples = self.target.sample(self.test_batch_size)
        self.target_z = self.bijector.inverse(self.target_samples)
        self.q_samples = self.q_x.sample(self.test_batch_size)

        self.target_cov = utils.Covariance(self.target_samples)
        self.target_eigvals, self.target_eigvecs = tf.linalg.eigh(
            self.target_cov)

        self.cached_target_eigvals = tf.get_local_variable(
            "cached_target_eigvals",
            self.target_eigvals.shape,
            initializer=tf.zeros_initializer())
        self.cached_target_eigvecs = tf.get_local_variable(
            "cached_target_eigvecs",
            self.target_eigvecs.shape,
            initializer=tf.zeros_initializer())
        self.cached_target_stats_update_op = [
            self.cached_target_eigvals.assign(self.target_eigvals),
            self.cached_target_eigvecs.assign(self.target_eigvecs),
            tf.print("Assigning target stats")
        ]

        def variance(x):
            x -= tf.reduce_mean(x, 0, keep_dims=True)
            x = tf.square(x)
            return x

        def rotated_variance(x):
            x2 = tf.reshape(x, [-1, self.target_spec.num_dims])
            x2 -= tf.reduce_mean(x2, 0, keep_dims=True)
            x2 = tf.matmul(x2, self.cached_target_eigvecs)
            x2 = tf.square(x2)
            return tf.reshape(x2, tf.shape(x))

        functions = [
            ("mean", tf.identity),
            #        ("var", variance),
            ("square", tf.square),
            #        ("rot_square", rot_square),
            #        ("rot_var", rotated_variance),
        ]

        self.cached_target_mean = {}
        self.cached_target_mean_update_op = [
            tf.print("Assigning target means.")
        ]
        self.neutra_stats = {}
        self.q_stats = {}

        for name, f in functions:
            target_mean = tf.reduce_mean(f(self.target_samples), 0)
            cached_target_mean = tf.get_local_variable(name + "_cached_mean",
                                                       target_mean.shape)
            if self.target_spec.stats is not None:
                self.cached_target_mean_update_op.append(
                    cached_target_mean.assign(self.target_spec.stats[name]))
            else:
                self.cached_target_mean_update_op.append(
                    cached_target_mean.assign(target_mean))

            self.cached_target_mean[name] = cached_target_mean
            self.q_stats[name] = ComputeQStats(f(self.q_samples),
                                               cached_target_mean)
            self.neutra_stats[name] = ComputeChainStats(
                f(self.neutra_outputs.x_chain), cached_target_mean,
                self.test_num_leapfrog_steps)

        # Training
        self.train_q_samples = self.q_x_train.sample(self.train_batch_size)
        self.train_log_q_x = self.q_x_train.log_prob(self.train_q_samples)
        self.kl_q_p = tf.reduce_mean(
            self.train_log_q_x - self.target.log_prob(self.train_q_samples))

        loss = self.kl_q_p
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        if reg_losses:
            tf.logging.info("Regularizing.")
            loss += tf.add_n(reg_losses)
        self.loss = tf.check_numerics(loss, "Loss has NaNs")

        self.global_step = tf.train.get_or_create_global_step()
        steps, factors = list(zip(*learning_rate_schedule))
        learning_rate = base_learning_rate * tf.train.piecewise_constant(
            self.global_step, steps, [1.0] + list(factors))

        opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
        self.train_op = opt.minimize(self.loss, global_step=self.global_step)

        tf.summary.scalar("kl_q_p", self.kl_q_p)
        tf.summary.scalar("loss", self.loss)

        self.init = [
            tf.global_variables_initializer(),
            tf.local_variables_initializer(),
            tf.print("Initializing variables")
        ]

        self.saver = tf.train.Saver()
        self.log_dir = log_dir
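
A hedged sketch of driving the constructor above; `Experiment` is a purely illustrative stand-in for whichever class defines this `__init__`, and the argument values are arbitrary:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

exp = Experiment(train_batch_size=1024, bijector="affine")

with tf.Session() as sess:
  sess.run(exp.init)                           # global + local variables
  sess.run(exp.cached_target_stats_update_op)  # cache target eigenvectors
  sess.run(exp.cached_target_mean_update_op)   # cache per-statistic means
  for _ in range(1000):
    loss, _ = sess.run([exp.loss, exp.train_op])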
Example #9
def retrieve(features, retriever_beam_size, mode, params):
    """Do retrieval."""
    tokenizer, vocab_lookup_table = bert_utils.get_tf_tokenizer(
        params["retriever_module_path"])

    question_token_ids = tokenizer.tokenize(
        tf.expand_dims(features["question"], 0))
    question_token_ids = tf.cast(
        question_token_ids.merge_dims(1, 2).to_tensor(), tf.int32)
    cls_token_id = vocab_lookup_table.lookup(tf.constant("[CLS]"))
    sep_token_id = vocab_lookup_table.lookup(tf.constant("[SEP]"))
    question_token_ids = tf.concat(
        [[[tf.cast(cls_token_id, tf.int32)]], question_token_ids,
         [[tf.cast(sep_token_id, tf.int32)]]], -1)

    retriever_module = hub.Module(
        params["retriever_module_path"],
        tags={"train"} if mode == tf.estimator.ModeKeys.TRAIN else {},
        trainable=True)

    # [1, projection_size]
    question_emb = retriever_module(inputs=dict(
        input_ids=question_token_ids,
        input_mask=tf.ones_like(question_token_ids),
        segment_ids=tf.zeros_like(question_token_ids)),
                                    signature="projected")

    block_emb, searcher = scann_utils.load_scann_searcher(
        var_name="block_emb",
        checkpoint_path=os.path.join(params["retriever_module_path"],
                                     "encoded", "encoded.ckpt"),
        num_neighbors=retriever_beam_size)

    # [1, retriever_beam_size]
    retrieved_block_ids, _ = searcher.search_batched(question_emb)

    # [1, retriever_beam_size, projection_size]
    retrieved_block_emb = tf.gather(block_emb, retrieved_block_ids)

    # [retriever_beam_size]
    retrieved_block_ids = tf.squeeze(retrieved_block_ids)

    # [retriever_beam_size, projection_size]
    retrieved_block_emb = tf.squeeze(retrieved_block_emb)

    # [1, retriever_beam_size]
    retrieved_logits = tf.matmul(question_emb,
                                 retrieved_block_emb,
                                 transpose_b=True)

    # [retriever_beam_size]
    retrieved_logits = tf.squeeze(retrieved_logits, 0)

    blocks_dataset = tf.data.TFRecordDataset(params["block_records_path"],
                                             buffer_size=512 * 1024 * 1024)
    blocks_dataset = blocks_dataset.batch(params["num_block_records"],
                                          drop_remainder=True)
    blocks = tf.get_local_variable(
        "blocks",
        initializer=tf.data.experimental.get_single_element(blocks_dataset))
    retrieved_blocks = tf.gather(blocks, retrieved_block_ids)
    return RetrieverOutputs(logits=retrieved_logits, blocks=retrieved_blocks)
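
A hedged sketch of the inputs `retrieve` expects, inferred from the body above; the module path, block-records path, and record count are placeholders:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

params = {
    "retriever_module_path": "/tmp/realm_retriever",  # hypothetical path
    "block_records_path": "/tmp/blocks.tfrecord",     # hypothetical path
    "num_block_records": 1000000,                     # hypothetical count
}
features = {"question": tf.constant("who wrote the opera carmen")}
outputs = retrieve(features, retriever_beam_size=5,
                   mode=tf.estimator.ModeKeys.PREDICT, params=params)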
Example #10
def bucket_by_quantiles(len_fn, batch_size, n_buckets, hist_bounds):
    """Dynamically bucket a `tf.data.Dataset` based on the element's length.

  Keeps track of a histogram of the input elements lengths, and yields batches
  of examples that belong to the same length quantile.

  Useful in cases where you want to bucket data, but don't know what the
  optimal bucketing ranges should be

  Args:
    len_fn: Function mapping elements in the dataset to an integer length
    batch_size: Maximum size of the output batches
    n_buckets: Number of quantiles to break the data into
    hist_bounds: List of integer bounds to use when building the histograms,
      should cover a range so that at most a single quantile of elements are
      lower/higher then the bucket range. More fine-grained buckets will make
      the histogram more precise, but adds to the computational overhead

  Raises:
    ValueError: If `hist_bounds` or `len_fn` are invalid
  Returns:
    A function that can be used with tf.data.Dataset.apply to batch a dataset
  """
    n_hist_binds = len(hist_bounds)

    if n_hist_binds < n_buckets:
        raise ValueError(
            "Requested %d buckets, but only have %d histogram bins" %
            (n_buckets, n_hist_binds))
    if any(hist_bounds[i] >= hist_bounds[i + 1]
           for i in range(n_hist_binds - 1)):
        raise ValueError("Bins must be descending")

    # Need to use `use_resource = True` to make this work correctly
    # within tf.data.Dataset
    hist_counts = tf.get_local_variable("hist-counts",
                                        n_hist_binds + 1,
                                        tf.int64,
                                        tf.zeros_initializer(),
                                        use_resource=True)
    hist_bounds = tf.constant(hist_bounds, tf.int64)

    def bucket_fn(x):
        """Compute the element bucket and update the histogram."""
        ix = len_fn(x)
        if ix.dtype == tf.int32:
            ix = tf.to_int64(ix)
        elif ix.dtype != tf.int64:
            raise ValueError("Len function returned a non-int")

        adds_to_bins = tf.to_int64(tf.greater(hist_bounds, ix))
        # pad with a 1 for the "larger than all" bin
        adds_to_bins = tf.pad(adds_to_bins, [[0, 1]], constant_values=1)
        new_counts = tf.assign_add(hist_counts, adds_to_bins)
        bin_ix = n_hist_binds - tf.reduce_sum(adds_to_bins)
        # Computes the quantile based on the counts of the example's bucket
        bucket_ix = tf.floordiv(((n_buckets - 1) * new_counts[bin_ix]),
                                new_counts[-1])
        return bucket_ix

    def reduce_fn(_, x):
        return x.padded_batch(batch_size,
                              dataset_ops.get_legacy_output_shapes(x))

    return tf.data.experimental.group_by_window(bucket_fn, reduce_fn,
                                                batch_size)
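
A minimal usage sketch on a toy dataset of variable-length sequences; the batch size, bucket count, and histogram bounds are arbitrary. (The example above also assumes `dataset_ops` has been imported from `tensorflow.python.data.ops`.)

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

dataset = tf.data.Dataset.from_generator(
    lambda: ([0] * n for n in range(1, 200)),
    output_types=tf.int64,
    output_shapes=tf.TensorShape([None]))
dataset = dataset.apply(
    bucket_by_quantiles(len_fn=lambda x: tf.shape(x)[0],
                        batch_size=8,
                        n_buckets=4,
                        hist_bounds=[16, 32, 64, 128]))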