def test(self):
    """Check masked max-pooling over a SequenceBatch and empty-sequence rejection.

    The first half reduces a 3x3x3 batch with ``reduce_max`` and compares the
    result against hand-computed per-sequence maxima; masked-out positions
    must not contribute. The second half builds a mask whose first row is all
    zeros and expects evaluating ``reduce_mean`` on it to raise
    ``InvalidArgumentError``.
    """
    # Expected per-sequence maxima. Row 0 ignores its third position
    # ([100, 200, 2000]) because the mask below zeroes that position out.
    expected = np.array(
        [[3, 5, 7],
         [3, 5, 7],
         [9, 9, 9]],
        dtype=np.float32,
    )

    with clean_session():
        # Three sequences, each with three positions holding 3-dim vectors.
        values = tf.constant(
            [[[1., 2., 3.], [3., 5., 7.], [100., 200., 2000.]],
             [[2., 4., 6.], [3., 5., 7.], [3., 5., 7.]],
             [[9., 9., 9.], [3., 5., 7.], [1., 2., 3.]]],
            dtype=tf.float32)
        valid_mask = tf.constant(
            [[1, 1, 0],
             [1, 1, 1],
             [1, 1, 1]],
            dtype=tf.float32)

        maxes = reduce_max(SequenceBatch(values, valid_mask))
        assert_almost_equal(maxes.eval(), expected, decimal=5)

        # The first sequence has no valid positions at all.
        empty_first_mask = tf.constant(
            [[0, 0, 0],
             [1, 1, 1],
             [1, 1, 1]],
            dtype=tf.float32)
        # NOTE(review): this exercises reduce_mean, not reduce_max, inside what
        # otherwise looks like a reduce_max test; confirm which reducer the
        # empty-sequence check was meant to cover.
        empty_means = reduce_mean(SequenceBatch(values, empty_first_mask))
        with pytest.raises(InvalidArgumentError):
            empty_means.eval()
def embed_sequences(self, embedded_sequence_batch):
    """Collapse each embedded sequence in the batch into one vector via ``reduce_max``.

    Args:
        embedded_sequence_batch: the batch to pool. Presumably a
            ``SequenceBatch`` of embeddings (TODO confirm against callers).

    Returns:
        Whatever ``reduce_max`` returns for the batch. This is presumably a
        masked element-wise maximum over each sequence's positions; verify
        against ``reduce_max``.
    """
    pooled = reduce_max(embedded_sequence_batch)
    return pooled