Ejemplo n.º 1
0
  def testDepth(self):
    """Checks that very deeply nested clauses survive sequence conversion.

    Builds one positive clause nested 100 levels deep and one negative clause
    nested 200 levels deep, runs them through
    `inputs.random_clauses_as_sequence`, and compares the decoded token ids
    with the expected `f(f(...(X)...))` sequences.
    """
    # Build a very simple vocabulary: one variable and one function symbol
    FLAGS.vocab = os.path.join(self.get_temp_dir(), 'depth_vocab')
    with open(FLAGS.vocab, 'w') as vocab_file:
      print('X\nf', file=vocab_file)
    _, vocab_to_id = inputs.read_vocab(FLAGS.vocab)

    # Build two very deep clauses
    def deep_clause(n, clause):
      """Fills `clause` with one equation whose left side is f nested n times."""
      term = clause.clause.equations.add().left
      for _ in range(n):
        term.function.name = 'f'
        term = term.function.args.add()
      term.variable.name = 'X'
    examples = prover_clause_examples_pb2.ProverClauseExamples()
    deep_clause(100, examples.positives.add())
    deep_clause(200, examples.negatives.add())

    # The clauses are f(f(...(X)...))
    def correct(n):
      """Returns the expected id sequence for a clause nested n levels deep."""
      tokens = ['f', '('] * n + ['X'] + [')'] * n
      return [vocab_to_id[s] for s in tokens]

    # Check that parsing works
    with self.test_session() as sess:
      _, negated_conjecture, clauses, labels = sess.run(
          inputs.random_clauses_as_sequence(examples.SerializeToString(),
                                            vocab=FLAGS.vocab))

      def decode(s):
        # The binary mode of np.fromstring is deprecated; np.frombuffer
        # interprets the same raw bytes as int32 without copying.
        # NOTE(review): assumes each fetched value is a bytes object.
        return np.frombuffer(s, dtype=np.int32)
      self.assertAllEqual(decode(negated_conjecture), [vocab_to_id['$true']])
      self.assertAllEqual(decode(clauses[0]), correct(100))
      self.assertAllEqual(decode(clauses[1]), correct(200))
      self.assertAllEqual(labels, [True, False])
Ejemplo n.º 2
0
  def testReadVocab(self):
    """Exercises `inputs.read_vocab` with and without `:one_variable`."""
    # Mirrors tests in vocabulary_test.cc
    words = '7', 'X', 'Yx', 'f', 'g'
    path = os.path.join(self.get_temp_dir(), 'small_vocab')
    with open(path, 'w') as vocab_file:
      for word in words:
        print(word, file=vocab_file)

    # Ids this test expects for the built-in tokens, whatever the flags
    reserved = {' ': 0, '*': 1, '~': 2, '|': 3, '&': 4, '(': 5, ')': 6,
                ',': 7, '=': 8, '$false': 9, '$true': 10}
    # Id this test expects for the first word read from the file
    first_id = 32

    def check(vocab_to_id, expect):
      """Asserts the ids of the file words and of every reserved token."""
      expect.update(reserved)
      for word, i in expect.items():
        self.assertEqual(vocab_to_id[word], i)

    # No flags: every word gets its own id
    size, vocab_to_id = inputs.read_vocab(path)
    self.assertEqual(size, first_id + 5)
    check(vocab_to_id,
          {word: first_id + k for k, word in enumerate(words)})

    # One variable: 'X' and 'Yx' are expected to share a single id
    size, vocab_to_id = inputs.read_vocab(path + ':one_variable')
    self.assertEqual(size, first_id + 4)
    check(vocab_to_id,
          {word: first_id + k for word, k in zip(words, (0, 1, 1, 2, 3))})
Ejemplo n.º 3
0
def full_model(mode, hparams):
    """Make a clause search model including input pipeline.

    Dispatches on the model kind reported by all_models.make_model:
    'sequence' models read batched id sequences and are embedded here,
    'tree' models are handed a weave callback over proto batches, and 'fast'
    models are handed the outputs of random_clauses_as_fast_clause.

    Args:
        mode: Either 'train' or 'eval'.
        hparams: Hyperparameters.  See default_hparams for details.

    Returns:
        logits, labels

    Raises:
        NotImplementedError: If hparams.use_averages is set.
        ValueError: If the model kind is not 'sequence', 'tree', or 'fast'.
            The original documentation also lists badly shaped model outputs
            as a cause; presumably that check lives in fix_logits, which is
            not visible here.
    """
    if hparams.use_averages:
        raise NotImplementedError(
            'Figure out how to eval with Polyak averaging')
    kind, model = all_models.make_model(name=hparams.model,
                                        mode=mode,
                                        hparams=hparams,
                                        vocab=FLAGS.vocab)
    batch_size = mode_batch_size(mode, hparams)

    if kind == 'sequence':
        # Read
        _, conjectures, clauses, labels = inputs.sequence_example_batch(
            mode=mode, batch_size=batch_size, shuffle=True)
        # Flatten the two clauses of every example into one batch axis.
        clauses = tf.reshape(clauses, [2 * batch_size, -1])
        labels = tf.reshape(labels, [2 * batch_size])

        # Embed
        vocab_size, _ = inputs.read_vocab(FLAGS.vocab)
        conjectures, clauses = model_utils.shared_embedding_layer(
            (conjectures, clauses),
            dim=hparams.embedding_size,
            size=vocab_size)

        # Classify
        conjectures = model.conjecture_embedding(conjectures)
        # Repeat every conjecture twice, in consecutive rows, so that row i of
        # `conjectures` lines up with row i of the flattened `clauses`.
        conjectures = tf.reshape(
            tf.tile(tf.reshape(conjectures, [batch_size, 1, -1]), [1, 2, 1]),
            [2 * batch_size, -1])
        clauses = model.axiom_embedding(clauses)
        logits = model.classifier(conjectures, clauses)
    elif kind == 'tree':
        examples = inputs.proto_batch(mode=mode, batch_size=batch_size)

        def weave(**ops):
            """Weave the batch with whatever loom ops the model supplies."""
            return clause_loom.weave_clauses(examples=examples,
                                             vocab=FLAGS.vocab,
                                             **ops)

        logits, labels = model(weave)
    elif kind == 'fast':
        examples = inputs.proto_batch(mode=mode, batch_size=batch_size)
        conjecture_sizes, conjecture_flat, clauses, labels = (
            gen_clause_ops.random_clauses_as_fast_clause(examples,
                                                         vocab=FLAGS.vocab))
        conjectures = jagged.Jagged(conjecture_sizes, conjecture_flat)
        logits = model(conjectures, clauses)
    else:
        # Without this branch `logits` and `labels` would be unbound below and
        # the failure would surface as an unrelated UnboundLocalError.
        raise ValueError('Unknown kind %r' % kind)

    # Done!
    return fix_logits(kind, logits), labels
Ejemplo n.º 4
0
def inference(hparams):
    """Build the clause search graph used for inference at proof time.

    The nodes below are created under fixed names so that they can be looked
    up from C++:

        conjecture, clauses: string, shape (?,), placeholders of serialized
            FastClause protos.
        conjecture_embeddings: float32, shape (dim,).
        logits: float32, shape (?,) output logits.
        initialize: Initialization op.

    Args:
        hparams: Hyperparameters.  See default_hparams for details.

    Returns:
        The tf.Saver object.

    Raises:
        NotImplementedError: If hparams.use_averages is set.
        ValueError: If the model kind is not 'sequence', 'tree', or 'fast'.
    """
    if hparams.use_averages:
        raise NotImplementedError(
            'Figure out how to eval with Polyak averaging')
    kind, model = all_models.make_model(
        name=hparams.model, mode='eval', hparams=hparams, vocab=FLAGS.vocab)

    # Placeholders that are fed serialized FastClause protos.
    conjecture = tf.placeholder(
        name='conjecture', shape=(None, ), dtype=tf.string)
    clauses = tf.placeholder(name='clauses', shape=(None, ), dtype=tf.string)

    def expand(embedding):
        """Repeat the single conjecture embedding once per clause."""
        multiples = tf.stack([tf.size(clauses), 1])
        tiled = tf.tile(embedding, multiples)
        # tf.tile loses the static width; restore it from the input.
        tiled.set_shape([None, embedding.get_shape()[-1]])
        return tiled

    if kind == 'sequence':
        # One embedding table shared by the conjecture and the clauses.
        vocab_size, _ = inputs.read_vocab(FLAGS.vocab)
        weights = model_utils.embedding_weights(
            dim=hparams.embedding_size, size=vocab_size)

        # Conjecture side: all negated-conjecture clauses become one sequence.
        conjecture_ids = gen_clause_ops.fast_clauses_as_sequence(
            conjecture, conjunction=True)
        conjecture_vectors = tf.nn.embedding_lookup(weights, conjecture_ids)
        # Add a singleton batch axis: many clauses form one ~conjecture.
        conjecture_vectors = conjecture_vectors[None]
        conjecture_embedding = with_name(
            model.conjecture_embedding(conjecture_vectors),
            name='conjecture_embeddings')

        # Clause side.
        clause_ids = gen_clause_ops.fast_clauses_as_sequence(clauses)
        clause_vectors = tf.nn.embedding_lookup(weights, clause_ids)
        clause_embeddings = model.axiom_embedding(clause_vectors)

        # Score every clause against the tiled conjecture.
        logits = model.classifier(
            expand(conjecture_embedding), clause_embeddings)
    elif kind == 'tree':

        def weave(embed, conjecture_apply, conjecture_not, conjecture_or,
                  conjecture_and, clause_apply, clause_not, clause_or,
                  combine):
            """Weave conjecture and clauses separately, then combine."""
            # The conjecture parts are concatenated so that a single tensor
            # can carry the name, then split back into their original widths.
            conjecture_parts = clause_loom.weave_fast_clauses(
                clauses=conjecture,
                embed=embed,
                apply_=conjecture_apply,
                not_=conjecture_not,
                or_=conjecture_or,
                and_=conjecture_and,
                shuffle=False)
            joined = tf.concat(
                conjecture_parts, 1, name='conjecture_embeddings')
            widths = [p.get_shape()[1].value for p in conjecture_parts]
            pieces = tf.split(joined, widths, axis=1)
            pieces = [expand(piece) for piece in pieces]

            # Clause side.
            clause_parts = clause_loom.weave_fast_clauses(
                clauses=clauses,
                embed=embed,
                apply_=clause_apply,
                not_=clause_not,
                or_=clause_or,
                shuffle=False)

            # Conjecture pieces first, clause pieces second.
            return combine.instantiate_batch(pieces + list(clause_parts))

        (logits, ) = model(weave)
    elif kind == 'fast':
        logits = model(jagged.pack([conjecture]), clauses)
    else:
        raise ValueError('Unknown kind %r' % kind)

    # Give the fixed-up logits their well-known name.
    with_name(fix_logits(kind, logits), name='logits')

    # Named init op, for testing purposes.
    with_name(tf.global_variables_initializer(), name='initialize')

    # The saver is created last, after the rest of the graph has been built.
    return tf.train.Saver()
Ejemplo n.º 5
0
def fast_model(conjectures, clauses, vocab, hparams, mode):
    """Compute classification logits for conjecture/clause pairs.

    Args:
        conjectures: Negated conjectures as a Jagged of serialized
            FastClauses.
        clauses: Clauses as serialized FastClauses.
        vocab: Path to vocabulary file.
        hparams: Hyperparameters.
        mode: Either 'train' or 'eval'.  Unused.

    Returns:
        Logits.
    """
    del mode  # Accepted for interface compatibility only
    hidden_size = hparams.hidden_size
    conv_layers = hparams.conv_layers

    # FastClauses -> jagged sequences of vocabulary ids
    conjectures = inputs.fast_clauses_as_sequence_jagged(conjectures)
    clauses = inputs.fast_clauses_as_sequence_jagged(clauses)

    # Look the ids up in one embedding table shared by both inputs
    vocab_size, _ = inputs.read_vocab(vocab)
    table = model_utils.embedding_weights(
        dim=hparams.embedding_size, size=vocab_size)
    conjecture_vectors = tf.nn.embedding_lookup(table, conjectures.flat)
    conjectures = jagged.jagged(conjectures.sizes, conjecture_vectors)
    clause_vectors = tf.nn.embedding_lookup(table, clauses.flat)
    clauses = jagged.jagged(clauses.sizes, clause_vectors)

    def bias_relu(x, bias):
        """Add `bias` and apply relu."""
        return tf.nn.relu(x + bias)

    def embed_clauses(clauses, name):
        """Run the conv stack over `clauses`, then max-reduce the result."""
        filters = []
        activations = []
        with tf.variable_scope(name):
            in_dim = hparams.embedding_size
            for layer in range(conv_layers):
                kernel = tf.get_variable(
                    'filter%d' % layer,
                    shape=(hparams.filter_width, in_dim, hidden_size),
                    initializer=layers.xavier_initializer())
                filters.append(kernel)
                bias = tf.get_variable(
                    'bias%d' % layer,
                    shape=(hidden_size, ),
                    initializer=tf.constant_initializer(0))
                activations.append(functools.partial(bias_relu, bias=bias))
                # Every layer after the first consumes hidden_size channels.
                in_dim = hidden_size
            convolved = jagged.conv1d_stack(clauses, filters, activations)
            return jagged.reduce_max(convolved)

    # Conjecture side: conv stack, optional relu layers, then a named max
    conjectures = embed_clauses(conjectures, 'conjectures')
    for _ in range(hparams.mid_layers):
        transformed = layers.relu(conjectures.flat, hidden_size)
        conjectures = jagged.Jagged(conjectures.sizes, transformed)
    conjectures = jagged.reduce_max(conjectures, name='conjecture_embeddings')

    # Clause side
    clauses = embed_clauses(clauses, 'clauses')

    # Repeat each conjecture enough times to line up with its clauses
    clause_count = tf.size(clauses)
    conjecture_count = tf.maximum(1, tf.size(conjectures))
    expansion = clause_count // conjecture_count
    repeated = tf.tile(conjectures[:, None], [1, expansion, 1])
    conjectures = tf.reshape(repeated, [-1, hidden_size])

    # Classify the concatenated pair
    net = tf.concat((conjectures, clauses), 1)
    net = layers.relu(net, hidden_size)
    return tf.squeeze(layers.linear(net, 1), [-1])
Ejemplo n.º 6
0
def loom_model(weave, vocab, hparams, mode):
  """Tree RNN model to compute logits from ProverClauseExamples.

  Defines the loom ops (token embedding, per-arity merge ops, and the final
  classifier) and hands them to `weave` as keyword arguments.

  Args:
    weave: Called with the loom op keyword arguments described in
        clause_loom.weave_clauses.
    vocab: Path to vocabulary file.
    hparams: Hyperparameters.
    mode: Either 'train' or 'eval'.

  Returns:
    The results of the call to `weave`.

  Raises:
    ValueError: From inside the op bodies, if embedding_size > hidden_size,
        if hparams.cell is not 'rnn-relu' or 'lstm', or if dropout is
        requested for 'rnn-relu'.  NOTE(review): these fire whenever
        as_loom_op / weave actually runs the op bodies -- confirm timing.
  """
  hidden_size = hparams.hidden_size
  embedding_size = hparams.embedding_size
  vocab_size, _ = inputs.read_vocab(vocab)
  # LSTMs carry two values per layer (memory, hidden); other cells carry one.
  per_layer = 2 if hparams.cell == 'lstm' else 1

  # TypeShapes
  vocab_id = clause_loom.VOCAB_ID
  logit = loom.TypeShape(tf.float32, (), 'logit')
  # Use separate embedding type shapes for separate layers to avoid confusion.
  # TODO(geoffreyi): Allow different sizes for different layers?
  embeddings = tuple(
      loom.TypeShape(tf.float32, (hidden_size,), 'embedding%d' % i)
      for i in range(hparams.layers * per_layer))

  @model_utils.as_loom_op([vocab_id], embeddings)
  def embed(ids):
    """Embed tokens and use a linear layer to get the right size."""
    values = model_utils.embedding_layer(ids, dim=embedding_size,
                                         size=vocab_size)
    # Only widening is supported; equal sizes pass through unchanged.
    if embedding_size < hidden_size:
      values = layers.linear(values, hidden_size)
    elif embedding_size > hidden_size:
      raise ValueError('embedding_size = %d > hidden_size = %d' %
                       (embedding_size, hidden_size))

    # Use relu layers to give one value per layer
    values = [values]
    for _ in range(hparams.layers - 1):
      # TODO(geoffreyi): Should these be relu or some other activation?
      values.append(layers.relu(values[-1], hidden_size))

    # If we're using LSTMs, initialize the memory cells to zero.
    if hparams.cell == 'lstm':
      memory = tf.zeros_like(values[0])
      # Interleave as (memory, hidden) per layer, matching `embeddings`.
      values = [v for hidden in values for v in (memory, hidden)]
    return values

  def merge(arity, name):
    """Merge arity inputs with rule according to hparams.cell."""
    # The inner op deliberately reuses the name `merge`; the loom op itself
    # is named by the explicit `name` argument.
    @model_utils.as_loom_op(embeddings * arity, embeddings, name=name)
    def merge(*args):
      """Process one batch of RNN inputs."""
      # shape = (arity, layers) for RNNs, (arity, layers, 2) for LSTMs,
      # where the 2 dimension is (memory, hidden).
      # Multiplying the tuple by a bool appends the trailing axis only when
      # per_layer > 1.
      shape = (arity, hparams.layers) + (per_layer,) * (per_layer > 1)
      args = np.asarray(args).reshape(shape)
      below = ()  # Information flowing up from the previous layer
      outputs = []  # Results of each layer
      if hparams.cell == 'rnn-relu':
        # Vanilla RNN with relu nonlinearities
        if hparams.keep_prob != 1:
          raise ValueError('No dropout allowed for vanilla RNNs')
        for layer in range(hparams.layers):
          output = layers.relu(
              tf.concat(below + tuple(args[:, layer]), 1), hidden_size)
          outputs.append(output)
          below = output,  # One-element tuple fed to the next layer
      elif hparams.cell == 'lstm':
        # Tree LSTM with separate forget gates per input and optional recurrent
        # dropout.  For details, see
        # 1. Improved semantic representations from tree-structured LSTM
        #    networks, http://arxiv.org/abs/1503.00075.
        # 2. Recurrent dropout without memory loss,
        #    http://arxiv.org/abs/1603.05118.
        # 3. http://colab/v2/notebook#fileId=0B2ewRpEjJXEFYjhtaExiZVBXbUk.
        # Move the trailing (memory, hidden) axis to the front and unpack it.
        memory, hidden = np.rollaxis(args, axis=-1)
        for layer in range(hparams.layers):
          # One linear layer yields the input gate, candidate, output gate,
          # and one forget gate per input.
          raw = layers.linear(
              tf.concat(below + tuple(hidden[:, layer]), 1),
              (3 + arity) * hidden_size)
          raw = tf.split(value=raw, num_or_size_splits=3 + arity, axis=1)
          (i, j, o), fs = raw[:3], raw[3:]
          j = tf.tanh(j)
          # Dropout is applied to the candidate only (see reference 2 above).
          j = layers.dropout(j, keep_prob=hparams.keep_prob,
                             is_training=mode == 'train')
          new_c = tf.add_n([tf.sigmoid(i) * j] +
                           [c * tf.sigmoid(f + hparams.forget_bias)
                            for c, f in zip(memory[:, layer], fs)])
          new_h = tf.tanh(new_c) * tf.sigmoid(o)
          outputs.extend((new_c, new_h))
          below = new_h,  # One-element tuple fed to the next layer
      else:
        # TODO(geoffreyi): Implement tree GRU?
        raise ValueError('Unknown rnn cell type %r' % hparams.cell)
      return outputs
    return merge

  @model_utils.as_loom_op(embeddings * 2, logit)
  def classify(*args):
    """Compute logits from conjecture and clause embeddings."""
    # Use the top layer, and either cell state, hidden state, or both
    which = {'cell': 0, 'hidden': 1, 'both': (0, 1)}[hparams.output_mode]
    # Axes: (conjecture/clause, layer, value within layer).
    # NOTE(review): index 1 ('hidden' or 'both') only exists when
    # per_layer == 2, i.e. for LSTM cells -- confirm for 'rnn-relu'.
    args = np.asarray(args).reshape(2, hparams.layers, per_layer)
    args = args[:, -1, which]
    value = layers.relu(tf.concat(tuple(args.flat), 1), hidden_size)
    return tf.squeeze(layers.linear(value, 1), [1])

  # Binary ops merge two inputs and `not` merges one; the conjecture and
  # clause sides each get their own separately named ops.
  return weave(
      embed=embed,
      conjecture_apply=merge(
          2, name='conjecture/apply'),
      conjecture_not=merge(
          1, name='conjecture/not'),
      conjecture_or=merge(
          2, name='conjecture/or'),
      conjecture_and=merge(
          2, name='conjecture/and'),
      clause_apply=merge(
          2, name='clause/apply'),
      clause_not=merge(
          1, name='clause/not'),
      clause_or=merge(
          2, name='clause/or'),
      combine=classify)
Ejemplo n.º 7
0
  def testSequence(self):
    """Round-trips random examples through `inputs.sequence_example_batch`.

    Generates random ProverClauseExamples together with their expected string
    form, writes them to sharded TFRecord files for both 'train' and 'eval',
    reads them back as id sequences, and checks that decoding the ids
    reproduces the expected strings.  With shuffling disabled it also checks
    that every valid key is visited exactly once.
    """
    # Random generation of ProverClauseExamples
    vocab = set()

    def random_list(limit, empty, separator, f):
      """Joins 0..limit-1 results of f(), or returns `empty` if there are none."""
      count = np.random.randint(limit)
      if not count:
        return empty
      return separator.join(f() for _ in range(count))

    def new_name(prefix):
      """Creates a fresh vocabulary word starting with `prefix`."""
      s = '%s%d' % (prefix, len(vocab))
      vocab.add(s)
      return s

    def random_term(term, depth):
      """Fills the `term` proto at random and returns its string form."""
      if depth == 0 or np.random.randint(3) == 0:
        # Leaf: either a variable or a number
        if np.random.randint(2):
          name = term.variable.name = new_name('X')
          return name
        else:
          name = term.number.value = new_name('')
          return name
      else:
        name = term.function.name = new_name('f')

        def random_arg():
          return random_term(term.function.args.add(), depth=depth - 1)

        # limit=2 means each function gets at most one argument
        args = random_list(2, '', ',', random_arg)
        return '%s(%s)' % (name, args) if args else name

    def random_equation(equation):
      """Fills the `equation` proto at random and returns its string form."""
      equation.negated = np.random.randint(2)
      s = '~' * equation.negated
      s += random_term(equation.left, depth=2)
      if np.random.randint(2):
        s += '=' + random_term(equation.right, depth=1)
      return s

    def random_clause(clause):
      """Fills `clause` with up to 3 equations; an empty clause is '$false'."""
      return random_list(4, '$false', '|',
                         lambda: random_equation(clause.clause.equations.add()))

    def random_clauses(clauses):
      """Adds up to 3 random clauses; an empty conjunction is '$true'."""
      return random_list(4, '$true', '&',
                         lambda: '(%s)' % random_clause(clauses.add()))

    np.random.seed(7)
    tf.set_random_seed(7)
    shards = 10
    batch_size = 2
    examples_per_shard = 6
    FLAGS.examples_train = os.path.join(self.get_temp_dir(),
                                        'examples-train@%d' % shards)
    FLAGS.examples_eval = os.path.join(self.get_temp_dir(),
                                       'examples-eval@%d' % shards)
    FLAGS.approx_proofs_per_shard = examples_per_shard
    FLAGS.input_queue_factor = 2

    # Build tfrecords of ProverClauseExamples
    key_info = {}
    mode_keys = {'train': set(), 'eval': set()}
    valid_keys = set()  # Keys with at least one positive and negative
    for mode in 'train', 'eval':
      for shard in range(shards):
        shard_path = os.path.join(
            self.get_temp_dir(),
            'examples-%s-%05d-of-%05d' % (mode, shard, shards))
        with tf.python_io.TFRecordWriter(shard_path) as writer:
          # Keep writing until the shard holds enough valid examples; invalid
          # ones (missing positives or negatives) are written too but are
          # not recorded in mode_keys.
          valid_count = 0
          while valid_count < examples_per_shard:
            key = 'key%d' % len(key_info)
            full_key = tf.compat.as_bytes('%s:%s' % (shard_path, key))
            examples = prover_clause_examples_pb2.ProverClauseExamples()
            examples.key = full_key
            conjecture = random_clauses(examples.cnf.negated_conjecture)
            positives = [random_clause(examples.positives.add())
                         for _ in range(np.random.randint(3))]
            negatives = [random_clause(examples.negatives.add())
                         for _ in range(np.random.randint(3))]
            writer.write(examples.SerializeToString())
            key_info[full_key] = Info(conjecture, positives, negatives)
            if positives and negatives:
              mode_keys[mode].add(full_key)
              valid_keys.add(full_key)
              valid_count += 1

    # Write vocab file
    vocab_path = os.path.join(self.get_temp_dir(), 'vocab')
    with open(vocab_path, 'w') as vocab_file:
      for s in vocab:
        print(s, file=vocab_file)
    FLAGS.vocab = vocab_path

    # Read vocabulary, and construct map from int sequence back to string
    vocab_size, vocab_to_id = inputs.read_vocab(vocab_path)
    # NOTE(review): presumably 32 ids are reserved, of which 11 are named
    # built-in tokens present in vocab_to_id -- confirm against read_vocab.
    self.assertEqual(vocab_size, len(vocab_to_id) + 32 - 11)
    id_to_vocab = {i: s for s, i in vocab_to_id.items()}

    def show_ids(ids):
      """Converts a coded clause to a string and strips surrounding spaces."""
      return ''.join(id_to_vocab[i] for i in ids).strip()

    # Test both train and eval, with and without shuffling
    for shuffle in False, True:
      if shuffle:
        buckets = '16,32,64,128,256,512'
      else:
        # Disable bucketing so that we can verify everything is processed
        buckets = '100000'
      FLAGS.negated_conjecture_buckets = FLAGS.clause_buckets = buckets
      for mode in 'train', 'eval':
        with tf.Graph().as_default() as graph:
          keys, conjectures, clauses, labels = (inputs.sequence_example_batch(
              mode=mode, batch_size=batch_size, shuffle=shuffle))
          init_op = tf.group(tf.global_variables_initializer(),
                             tf.local_variables_initializer())
          self.assertEqual(keys.dtype, tf.string)
          self.assertEqual(conjectures.dtype, tf.int32)
          self.assertEqual(clauses.dtype, tf.int32)
          self.assertEqual(labels.dtype, tf.bool)

          # Evaluate enough batches to cover one pass over the valid keys,
          # so that without shuffling every key is seen exactly once
          with self.test_session(graph=graph) as sess:
            init_op.run()
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            visited = collections.defaultdict(int)
            for _ in range(len(mode_keys[mode]) // batch_size):
              batch = sess.run([keys, conjectures, clauses, labels])
              for data in batch:
                self.assertEqual(len(data), batch_size)
              # Clauses and labels come in pairs per example
              for pair in batch[2:]:
                self.assertEqual(pair.shape[1], 2)
              for key, conjecture, clause_pair, label_pair in zip(*batch):
                self.assertIn(key, mode_keys[mode],
                              'mode %s, key %r, keys %r' %
                              (mode, key, mode_keys[mode]))
                visited[key] += 1
                info = key_info[key]
                self.assertEqual(info.conjecture, show_ids(conjecture))
                # A clause labelled True must be one of the positives, and a
                # clause labelled False one of the negatives
                for clause, label in zip(clause_pair, label_pair):
                  self.assertIn(show_ids(clause),
                                info.positives if label else info.negatives)
            coord.request_stop()
            for thread in threads:
              thread.join()

        if not shuffle:
          # Verify that we visited everything exactly once
          for key in mode_keys[mode]:
            count = visited[key]
            if count != 1:
              raise ValueError('key %s visited %d != 1 times' % (key, count))
Ejemplo n.º 8
0
  def _loomTest(self, shuffle, layers):
    """Weaves random examples with string-valued ops and checks the output.

    Every "embedding" in this loom is a scalar string and every op performs
    some form of string concatenation, so the woven result is the string
    representation of the input.  That result is compared with the strings
    produced while generating the ProverClauseExamples protos.

    Args:
      shuffle: Passed to clause_loom.weave_clauses; when true, strings are
          compared after sorting their '|' and '&' separated parts.
      layers: Number of embedding layers; the ops here handle 1 and 2.
    """
    np.random.seed(7)

    # Vocabulary: ten each of function names, variable names, and numbers
    vocab_path = os.path.join(self.get_temp_dir(), 'vocab')
    with open(vocab_path, 'w') as vocab_file:
      for kind in 'f', 'X', '':
        for i in range(10):
          print('%s%d' % (kind, i), file=vocab_file)
    vocab_size, vocab_to_id = inputs.read_vocab(vocab_path)
    # Inverse table, id -> word; ids without a word stay ''
    vocab = [''] * vocab_size
    for word, word_id in vocab_to_id.items():
      vocab[word_id] = word

    # Conjecture and clause ops insert different tags, so using the wrong op
    # would leave a tag that the matching clean_* helper does not strip
    conjecture_tag = b'A'
    clause_tag = b'B'

    def order(s):
      """Canonicalizes part order when shuffling may have permuted it."""
      if not shuffle:
        return s
      conjuncts = [b'|'.join(sorted(c.split(b'|'))) for c in s.split(b'&')]
      return b'&'.join(sorted(conjuncts))

    def clean_conjecture(s):
      without_tag = s.replace(conjecture_tag, b'')
      return order(without_tag)

    def clean_clause(s):
      without_tag = s.replace(clause_tag, b'')
      return order(without_tag)

    @model_utils.as_loom_op([VOCAB_ID], embeddings(layers))
    def embed(ids):
      """Turn a vocab_id back into the string it represents."""
      e0 = tf.gather(vocab, ids)
      if layers == 1:
        return (e0,)
      elif layers == 2:
        return (e0, reverse(e0))

    @model_utils.as_loom_op(embeddings(layers) * 2, COMBINATION)
    def combine(*args):
      """Combine conjecture, clause, and label."""
      if layers == 1:
        conjecture, clause = args
        return tf.stack([conjecture, clause], axis=-1)
      elif layers == 2:
        x0, x1, y0, y1 = args
        first = tf.stack([x0, y0], axis=-1)
        second = tf.stack([x1, y1], axis=-1)
        # The second layer carries reversed strings; both must agree
        return assert_same(first, reverse(second))

    # Make a loom that reconstructs the string representation of the input.
    # The ops are built in the same order in which they are passed below.
    placeholder = tf.placeholder(tf.string)
    conjecture_apply = apply_op(conjecture_tag, layers=layers)
    conjecture_not = not_op(conjecture_tag, layers=layers)
    conjecture_or = binary_op(conjecture_tag, b'|', layers=layers)
    conjecture_and = binary_op(conjecture_tag, b'&', layers=layers)
    clause_apply = apply_op(clause_tag, layers=layers)
    clause_not = not_op(clause_tag, layers=layers)
    clause_or = binary_op(clause_tag, b'|', layers=layers)
    pairs, labels = clause_loom.weave_clauses(
        examples=placeholder,
        vocab=vocab_path,
        shuffle=shuffle,
        embed=embed,
        conjecture_apply=conjecture_apply,
        conjecture_not=conjecture_not,
        conjecture_or=conjecture_or,
        conjecture_and=conjecture_and,
        clause_apply=clause_apply,
        clause_not=clause_not,
        clause_or=clause_or,
        combine=combine)
    self.assertEqual(pairs.dtype, tf.string)
    self.assertEqual(labels.dtype, tf.bool)

    # Test it out on several batch sizes, including the empty batch
    with self.test_session() as sess:
      for batch_size in 0, 1, 3, 5:
        all_examples = []
        conjectures = []
        clauses = []
        for _ in range(batch_size):
          examples = prover_clause_examples_pb2.ProverClauseExamples()
          # One conjecture is paired with both its positive and its negative
          conjecture = order(random_clauses(examples.cnf.negated_conjecture))
          conjectures.extend([conjecture, conjecture])
          positive = order(random_clause(examples.positives.add()))
          negative = order(random_clause(examples.negatives.add()))
          clauses.extend([positive, negative])
          all_examples.append(examples.SerializeToString())
        pairs_np, labels_np = sess.run([pairs, labels],
                                       feed_dict={placeholder: all_examples})
        self.assertEqual(pairs_np.shape, (2 * batch_size, 2))
        self.assertEqual(labels_np.shape, (2 * batch_size,))
        woven_conjectures = [clean_conjecture(c) for c in pairs_np[:, 0]]
        woven_clauses = [clean_clause(c) for c in pairs_np[:, 1]]
        self.assertEqual(conjectures, woven_conjectures)
        self.assertEqual(clauses, woven_clauses)
        self.assertEqual([True, False] * batch_size, list(labels_np))