def testDiscreteBipartiteFlowCall(self):
   """Checks the output shape and value range of a forward flow call.

   Feeds random one-hot token sequences through a `DiscreteBipartiteFlow`
   that wraps `tf.identity` and asserts that the result has shape
   `(batch_size, length, vocab_size)` with every entry inside
   `[0, vocab_size - 1]`.
   """
   batch_size = 3
   vocab_size = 79
   length = 5
   # `np.random.randint` excludes `high`, so pass `vocab_size` itself to make
   # every token index (including the last one) reachable.
   inputs = np.random.randint(0, vocab_size, size=(batch_size, length))
   inputs = tf.one_hot(inputs, depth=vocab_size, dtype=tf.float32)
   layer = reversible.DiscreteBipartiteFlow(
       tf.identity,
       # Random integer mask with one entry per position, drawn from [0, 2).
       mask=tf.random_uniform([length], minval=0, maxval=2, dtype=tf.int32),
       temperature=1.)
   outputs = layer(inputs)
   self.evaluate(tf.global_variables_initializer())
   outputs_val = self.evaluate(outputs)
   self.assertEqual(outputs_val.shape, (batch_size, length, vocab_size))
   self.assertAllGreaterEqual(outputs_val, 0)
   # NOTE(review): this upper bound is loose for one-hot shaped outputs;
   # confirm whether a tighter bound is intended.
   self.assertAllLessEqual(outputs_val, vocab_size - 1)
 def testDiscreteBipartiteFlowInverse(self):
   """Checks that `reverse` and the forward call undo each other.

   Builds a `DiscreteBipartiteFlow` around `tf.identity`, applies
   forward-then-reverse and reverse-then-forward to random one-hot inputs,
   and asserts that both round trips reproduce the inputs.
   """
   batch_size = 2
   vocab_size = 79
   length = 5
   # `np.random.randint` excludes `high`, so pass `vocab_size` itself to make
   # every token index (including the last one) reachable.
   inputs = np.random.randint(0, vocab_size, size=(batch_size, length))
   inputs = tf.one_hot(inputs, depth=vocab_size, dtype=tf.float32)
   layer = reversible.DiscreteBipartiteFlow(
       tf.identity,
       # Random integer mask with one entry per position, drawn from [0, 2).
       mask=tf.random_uniform([length], minval=0, maxval=2, dtype=tf.int32),
       temperature=1.)
   rev_fwd_inputs = layer.reverse(layer(inputs))
   fwd_rev_inputs = layer(layer.reverse(inputs))
   self.evaluate(tf.global_variables_initializer())
   # NOTE(review): all three tensors are fetched in a single evaluate call;
   # in graph mode separate calls would presumably re-sample the random mask.
   inputs_val, rev_fwd_inputs_val, fwd_rev_inputs_val = self.evaluate(
       [inputs, rev_fwd_inputs, fwd_rev_inputs])
   self.assertAllClose(inputs_val, rev_fwd_inputs_val)
   self.assertAllClose(inputs_val, fwd_rev_inputs_val)