def testSinkhornAutoregressiveFlowCall(self):
    """Check the forward pass of SinkhornAutoregressiveFlow on one-hot input.

    Builds a random batch of one-hot token sequences, runs it through the
    layer, and verifies that the output has shape
    (batch_size, length, vocab_size) with every entry inside
    [0, vocab_size - 1].
    """
    batch_size = 3
    vocab_size = 79
    length = 5
    # NOTE(review): presumably one vocab_size x vocab_size matrix per
    # position is what the flow expects from the MADE network -- confirm.
    units = vocab_size ** 2

    # np.random.randint excludes `high`, so pass vocab_size itself; using
    # vocab_size - 1 would never sample the last token id.
    token_ids = np.random.randint(0, vocab_size, size=(batch_size, length))
    inputs = tf.one_hot(token_ids, depth=vocab_size, dtype=tf.float32)

    # MADE is built with no hidden layers; the trailing 1. is presumably the
    # flow's temperature -- confirm against SinkhornAutoregressiveFlow.
    layer = reversible.SinkhornAutoregressiveFlow(
        reversible.MADE(units, []), 1.)
    outputs = layer(inputs)

    # Variables are created by the layer call above, so initialize afterwards.
    self.evaluate(tf.global_variables_initializer())
    outputs_val = self.evaluate(outputs)

    self.assertEqual(outputs_val.shape, (batch_size, length, vocab_size))
    self.assertAllGreaterEqual(outputs_val, 0)
    self.assertAllLessEqual(outputs_val, vocab_size - 1)
def testDiscreteSinkhornFlowInverse(self):
    """Check that the layer's call and reverse undo each other.

    Builds a random batch of one-hot token sequences and asserts that both
    reverse(call(x)) and call(reverse(x)) reproduce x exactly.
    """
    batch_size = 2
    vocab_size = 79
    length = 5
    # NOTE(review): presumably one vocab_size x vocab_size matrix per
    # position is what the flow expects from the MADE network -- confirm.
    units = vocab_size ** 2

    # np.random.randint excludes `high`, so pass vocab_size itself; using
    # vocab_size - 1 would never sample the last token id.
    token_ids = np.random.randint(0, vocab_size, size=(batch_size, length))
    inputs = tf.one_hot(token_ids, depth=vocab_size, dtype=tf.float32)

    # MADE is built with no hidden layers; the trailing 1. is presumably the
    # flow's temperature -- confirm against SinkhornAutoregressiveFlow.
    layer = reversible.SinkhornAutoregressiveFlow(
        reversible.MADE(units, []), 1.)

    # Exercise both composition orders of the forward and reverse passes.
    rev_fwd_inputs = layer.reverse(layer(inputs))
    fwd_rev_inputs = layer(layer.reverse(inputs))

    # Variables are created by the layer calls above, so initialize afterwards.
    self.evaluate(tf.global_variables_initializer())
    inputs_val, rev_fwd_inputs_val, fwd_rev_inputs_val = self.evaluate(
        [inputs, rev_fwd_inputs, fwd_rev_inputs])

    self.assertAllEqual(inputs_val, rev_fwd_inputs_val)
    self.assertAllEqual(inputs_val, fwd_rev_inputs_val)