def test_get_best_span(self):  # pylint: disable=protected-access
    """Check ``BidirectionalAttentionFlow._get_best_span`` on hand-computed cases.

    These cases use inclusive span ends, so a span may begin and end on the same
    position. Each expected ``[begin, end]`` is the pair with the largest
    ``p_begin[begin] * p_end[end]`` subject to ``begin <= end``.
    """
    def assert_best_span(begin_probs, end_probs, expected):
        # Build a batch of one passage and pass log-probabilities, exactly as the
        # original per-case code did.
        span_begin_probs = Variable(torch.FloatTensor([begin_probs])).log()
        span_end_probs = Variable(torch.FloatTensor([end_probs])).log()
        begin_end_idxs = BidirectionalAttentionFlow._get_best_span(span_begin_probs,
                                                                   span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [expected])

    # (0, 0) scores 0.1 * 0.65 = 0.065, narrowly beating (1, 2) at 0.3 * 0.2 = 0.06.
    assert_best_span([0.1, 0.3, 0.05, 0.3, 0.25],
                     [0.65, 0.05, 0.2, 0.05, 0.05],
                     [0, 0])

    # When we were using exclusive span ends, this was an edge case of the dynamic program.
    # We're keeping the test to make sure we get it right now, after the switch to inclusive
    # span ends. The best answer is (1, 1).
    assert_best_span([0.4, 0.5, 0.1], [0.3, 0.6, 0.1], [1, 1])

    # Another instance that used to be an edge case: (0, 0) scores 0.8 * 0.8 = 0.64.
    assert_best_span([0.8, 0.1, 0.1], [0.8, 0.1, 0.1], [0, 0])

    # (1, 2) scores 0.2 * 0.5 = 0.1, the best pair with begin <= end.
    assert_best_span([0.1, 0.2, 0.05, 0.3, 0.25],
                     [0.1, 0.2, 0.5, 0.05, 0.15],
                     [1, 2])
def test_get_best_span(self):  # pylint: disable=protected-access
    """Check ``BidirectionalAttentionFlow._get_best_span`` on hand-computed cases.

    These cases use inclusive span ends, so a span may begin and end on the same
    position. Each expected ``[begin, end]`` is the pair with the largest
    ``p_begin[begin] * p_end[end]`` subject to ``begin <= end``.
    """
    def assert_best_span(begin_probs, end_probs, expected):
        # Build a batch of one passage and pass log-probabilities, exactly as the
        # original per-case code did.
        span_begin_probs = Variable(torch.FloatTensor([begin_probs])).log()
        span_end_probs = Variable(torch.FloatTensor([end_probs])).log()
        begin_end_idxs = BidirectionalAttentionFlow._get_best_span(span_begin_probs,
                                                                   span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [expected])

    # (0, 0) scores 0.1 * 0.65 = 0.065, narrowly beating (1, 2) at 0.3 * 0.2 = 0.06.
    assert_best_span([0.1, 0.3, 0.05, 0.3, 0.25],
                     [0.65, 0.05, 0.2, 0.05, 0.05],
                     [0, 0])

    # When we were using exclusive span ends, this was an edge case of the dynamic program.
    # We're keeping the test to make sure we get it right now, after the switch to inclusive
    # span ends. The best answer is (1, 1).
    assert_best_span([0.4, 0.5, 0.1], [0.3, 0.6, 0.1], [1, 1])

    # Another instance that used to be an edge case: (0, 0) scores 0.8 * 0.8 = 0.64.
    assert_best_span([0.8, 0.1, 0.1], [0.8, 0.1, 0.1], [0, 0])

    # (1, 2) scores 0.2 * 0.5 = 0.1, the best pair with begin <= end.
    assert_best_span([0.1, 0.2, 0.05, 0.3, 0.25],
                     [0.1, 0.2, 0.5, 0.05, 0.15],
                     [1, 2])
def test_get_best_span(self):  # pylint: disable=protected-access
    """Check ``BidirectionalAttentionFlow._get_best_span`` on hand-computed cases.

    These cases use exclusive span ends: the end index must be strictly greater
    than the begin index. Each expected ``[begin, end]`` is the pair with the
    largest ``p_begin[begin] * p_end[end]`` under that constraint.
    """
    def assert_best_span(begin_probs, end_probs, expected):
        # Build a batch of one passage and pass log-probabilities, exactly as the
        # original per-case code did.
        span_begin_probs = Variable(torch.FloatTensor([begin_probs])).log()
        span_end_probs = Variable(torch.FloatTensor([end_probs])).log()
        begin_end_idxs = BidirectionalAttentionFlow._get_best_span(span_begin_probs,
                                                                   span_end_probs)
        assert_almost_equal(begin_end_idxs.data.numpy(), [expected])

    # Note that the best span cannot be (1, 0) since even though 0.3 * 0.5 is the greatest
    # value, the end span index is constrained to occur after the begin span index.
    assert_best_span([0.1, 0.3, 0.05, 0.3, 0.25],
                     [0.5, 0.1, 0.2, 0.05, 0.15],
                     [1, 2])

    # Testing an edge case of the dynamic program here, for the order of when you update the
    # best previous span position. We should not get (1, 1), because that's an empty span.
    assert_best_span([0.4, 0.5, 0.1], [0.3, 0.6, 0.1], [0, 1])

    # Testing another edge case of the dynamic program here, where (0, 0) is the best solution
    # without constraints. NOTE: (0, 1) and (0, 2) both score exactly 0.8 * 0.1, so this
    # expectation also requires the tie to resolve to the earlier end index.
    assert_best_span([0.8, 0.1, 0.1], [0.8, 0.1, 0.1], [0, 1])

    # Same begin probabilities as the first case, with the peak end probability moved to
    # index 1. The best span cannot be (1, 1) since even though 0.3 * 0.5 is the greatest
    # value, the end span index is constrained to occur after the begin span index.
    assert_best_span([0.1, 0.3, 0.05, 0.3, 0.25],
                     [0.1, 0.5, 0.2, 0.05, 0.15],
                     [1, 2])