Example #1
    def __init__(self, device, inp_dim, hidden_dim, nenc_lay, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Bidirectional LSTM over the input word embeddings.
        self.enc_blstm = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(inp_dim,
                          hidden_dim,
                          batch_first=True,
                          bidirectional=True,
                          num_layers=nenc_lay))
        # Span extractor over the BLSTM outputs (dim = hidden_dim * 2).
        self._span_encoder = BidirectionalEndpointSpanExtractor(
            self.enc_blstm.get_output_dim())
        self._dropout = torch.nn.Dropout(p=dropout)
        self.device = device
Example #2
    def test_forward_raises_with_invalid_indices(self):
        sequence_tensor = torch.randn([2, 5, 8])
        extractor = BidirectionalEndpointSpanExtractor(input_dim=8)
        indices = torch.LongTensor([[[-1, 3], [7, 4]], [[0, 12], [0, -1]]])

        with pytest.raises(ValueError):
            _ = extractor(sequence_tensor, indices)
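For reference, an illustrative check (not part of the original test): with sequence_length = 5, valid token positions run from 0 to 4, so span endpoints such as 7 and 12 above fall outside the sequence and trip the extractor's bounds check.

import torch

sequence_length = 5
indices = torch.LongTensor([[[-1, 3], [7, 4]], [[0, 12], [0, -1]]])
# Positions outside [0, sequence_length - 1] are what make these spans invalid.
out_of_range = (indices < 0) | (indices > sequence_length - 1)
print(out_of_range.any())   # tensor(True)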
Example #3
class EncWord2Sent(torch.nn.Module):
    # A standard bidirectional LSTM encoder from word embeddings to hidden representations.
    def __init__(self, device, inp_dim, hidden_dim, nenc_lay, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.enc_blstm = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(inp_dim,
                          hidden_dim,
                          batch_first=True,
                          bidirectional=True,
                          num_layers=nenc_lay))
        self._span_encoder = BidirectionalEndpointSpanExtractor(
            self.enc_blstm.get_output_dim())
        self._dropout = torch.nn.Dropout(p=dropout)
        self.device = device

    def get_output_dim(self):
        return self.enc_blstm.get_output_dim()

    def forward(self, context, context_msk):
        """

        :param context: [batch, t, inp_dim]
        :param context_msk: [batch, t] [0,1]
        :return: blstm_output [batch, t, hid_dim*2]
                avg_blstm_out [batch, hid*2]
        """
        batch_size = context.size()[0]
        blstm_output = self._dropout(self.enc_blstm(context,
                                                    context_msk))  # blstm
        context_msk[:, 0] = 1  # ensure every example has at least one valid position
        # context_msk = [batch, t]
        context_mask_sum = torch.sum(context_msk, dim=1) - 1
        # context_mask_sum = [batch]   [ 40, 4, 9, 0(sent len=1) , -1(sent len=0), -1]
        sum_of_mask = context_mask_sum.unsqueeze(dim=1).unsqueeze(dim=2).long()

        # Build one span per example: start = 0 for non-empty sequences
        # (or -1, which the extractor masks out, for empty ones) and
        # end = last valid position according to the mask.
        span_idx = torch.ones(
            (batch_size), device=self.device, dtype=torch.long) * -1
        valid_bit = (context_mask_sum >= 0).long()
        span_idx = span_idx + valid_bit
        span_idx = span_idx.view((batch_size, 1, 1))
        span_idx = torch.cat([span_idx, sum_of_mask], dim=2).long()
        # Span extractor expects: sequence_tensor (batch_size, sequence_length, embedding_size),
        #                         span_indices (batch_size, num_spans, 2)
        attended_text_embeddings = self._span_encoder.forward(
            sequence_tensor=blstm_output,
            span_indices=span_idx,
            sequence_mask=context_msk,
            span_indices_mask=valid_bit.unsqueeze(1))
        # attended_text_embeddings: batch, 1, dim

        attended_text_embeddings = attended_text_embeddings.squeeze(1)
        # valid_len = context_msk.sum(dim=1).unsqueeze(1)  # batchsz,1.
        # context_msk = context_msk.unsqueeze(2)
        # msked_blstm_out = context_msk * blstm_output
        attended_text_embeddings = self._dropout(attended_text_embeddings)
        return blstm_output, attended_text_embeddings
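A minimal usage sketch for the class above (not from the original repository): the batch size, sequence length, dimensions, and mask values are illustrative, and it assumes the pre-1.0 AllenNLP API that the class itself targets, where masks are LongTensors.

import torch

device = torch.device("cpu")
enc = EncWord2Sent(device=device, inp_dim=100, hidden_dim=64, nenc_lay=1, dropout=0.1)

context = torch.randn(2, 7, 100)                         # [batch, t, inp_dim]
context_msk = torch.LongTensor([[1, 1, 1, 1, 1, 0, 0],
                                [1, 1, 1, 0, 0, 0, 0]])  # [batch, t]

blstm_out, sent_rep = enc(context, context_msk)
print(blstm_out.shape)  # torch.Size([2, 7, 128]), i.e. [batch, t, hidden_dim * 2]
print(sent_rep.shape)   # [batch, span-extractor output dim]; 128 with the default combinations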
Example #4
    def test_forward_doesnt_raise_with_empty_sequence(self):
        # size: (batch_size=1, sequence_length=2, emb_dim=2)
        sequence_tensor = torch.FloatTensor([[[0., 0.], [0., 0.]]])
        # size: (batch_size=1, sequence_length=2)
        sequence_mask = torch.LongTensor([[0, 0]])
        # size: (batch_size=1, spans_count=1, 2)
        span_indices = torch.LongTensor([[[-1, -1]]])
        # size: (batch_size=1, spans_count=1)
        span_indices_mask = torch.LongTensor([[0]])
        extractor = BidirectionalEndpointSpanExtractor(
            input_dim=2, forward_combination="x,y", backward_combination="x,y")
        span_representations = extractor(sequence_tensor,
                                         span_indices,
                                         sequence_mask=sequence_mask,
                                         span_indices_mask=span_indices_mask)
        numpy.testing.assert_array_equal(
            span_representations.detach(),
            torch.FloatTensor([[[0., 0., 0., 0.]]]))
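The all-zero expectation above comes from the span mask: in the AllenNLP versions these tests target, the extractor multiplies the computed span embeddings by span_indices_mask before returning them, so a fully masked span comes out as a zero vector. A toy version of that final step, using the shapes from the test:

import torch

span_embeddings = torch.randn(1, 1, 4)       # stand-in for whatever the extractor computed
span_indices_mask = torch.LongTensor([[0]])  # the single span is masked out
masked = span_embeddings * span_indices_mask.unsqueeze(-1).float()
assert torch.equal(masked, torch.zeros(1, 1, 4))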
Example #5
    def test_correct_sequence_elements_are_embedded(self):
        sequence_tensor = torch.randn([2, 5, 8])
        # Concatenate the start and end points together to form our representation
        # for both the forward and backward directions.
        extractor = BidirectionalEndpointSpanExtractor(
            input_dim=8, forward_combination="x,y", backward_combination="x,y")
        indices = torch.LongTensor([[[1, 3], [2, 4]], [[0, 2], [3, 4]]])

        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 16]
        assert extractor.get_output_dim() == 16
        assert extractor.get_input_dim() == 8

        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        (
            forward_start_embeddings,
            forward_end_embeddings,
            backward_start_embeddings,
            backward_end_embeddings,
        ) = span_representations.split(4, -1)

        forward_sequence_tensor, backward_sequence_tensor = sequence_tensor.split(
            4, -1)

        # Forward direction => subtract 1 from start indices to make them exclusive.
        correct_forward_start_indices = torch.LongTensor([[0, 1], [-1, 2]])
        # This index should be -1, so it will be replaced with a sentinel. Here,
        # we'll set it to a value other than -1 so we can index select the indices and
        # replace it later.
        correct_forward_start_indices[1, 0] = 1

        # Forward direction => end indices are the same.
        correct_forward_end_indices = torch.LongTensor([[3, 4], [2, 4]])

        # Backward direction => start indices are exclusive, so add 1 to the end indices.
        correct_backward_start_indices = torch.LongTensor([[4, 5], [3, 5]])
        # These exclusive backward start indices are outside the tensor, so they will be replaced
        # with the end sentinel. Here we replace them with ones so we can index select using these
        # indices without torch complaining.
        correct_backward_start_indices[0, 1] = 1
        correct_backward_start_indices[1, 1] = 1
        # Backward direction => end indices are inclusive and equal to the forward start indices.
        correct_backward_end_indices = torch.LongTensor([[1, 2], [0, 3]])

        correct_forward_start_embeddings = batched_index_select(
            forward_sequence_tensor.contiguous(),
            correct_forward_start_indices)
        # This element had sequence_tensor index of 0, so its exclusive index is the start sentinel.
        correct_forward_start_embeddings[1, 0] = extractor._start_sentinel.data
        numpy.testing.assert_array_equal(
            forward_start_embeddings.data.numpy(),
            correct_forward_start_embeddings.data.numpy())

        correct_forward_end_embeddings = batched_index_select(
            forward_sequence_tensor.contiguous(), correct_forward_end_indices)
        numpy.testing.assert_array_equal(
            forward_end_embeddings.data.numpy(),
            correct_forward_end_embeddings.data.numpy())

        correct_backward_end_embeddings = batched_index_select(
            backward_sequence_tensor.contiguous(),
            correct_backward_end_indices)
        numpy.testing.assert_array_equal(
            backward_end_embeddings.data.numpy(),
            correct_backward_end_embeddings.data.numpy())

        correct_backward_start_embeddings = batched_index_select(
            backward_sequence_tensor.contiguous(),
            correct_backward_start_indices)
        # This element had sequence_tensor index == sequence_tensor.size(1),
        # so its exclusive index is the end sentinel.
        correct_backward_start_embeddings[0, 1] = extractor._end_sentinel.data
        correct_backward_start_embeddings[1, 1] = extractor._end_sentinel.data
        numpy.testing.assert_array_equal(
            backward_start_embeddings.data.numpy(),
            correct_backward_start_embeddings.data.numpy())
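The index bookkeeping the comments above describe can be summarised in a few lines (an illustrative sketch, not library code): for a span (start, end), the forward direction uses the exclusive position start - 1 and the inclusive position end, while the backward direction uses the exclusive position end + 1 and the inclusive position start; exclusive positions that fall outside the sequence are replaced by the extractor's learned start/end sentinels.

import torch

spans = torch.LongTensor([[1, 3], [2, 4], [0, 2], [3, 4]])  # the (start, end) pairs above
starts, ends = spans[:, 0], spans[:, 1]

forward_start = starts - 1   # exclusive; -1 means "use the start sentinel"
forward_end = ends           # inclusive
backward_start = ends + 1    # exclusive; >= sequence length means "use the end sentinel"
backward_end = starts        # inclusive

print(forward_start.tolist())   # [0, 1, -1, 2]
print(backward_start.tolist())  # [4, 5, 3, 5]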
Example #6
    def test_raises_on_odd_input_dimension(self):
        with pytest.raises(ConfigurationError):
            _ = BidirectionalEndpointSpanExtractor(7)
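The even-dimension requirement follows from how the extractor interprets its input (an explanatory sketch, not part of the original tests): the feature dimension is split into equal forward and backward halves, exactly as the tests above do with sequence_tensor.split(4, -1), so an odd input_dim has no valid split and the constructor rejects it.

import torch

sequence_tensor = torch.randn(2, 5, 8)
# First half of the feature dimension = forward states, second half = backward states.
forward_half, backward_half = sequence_tensor.split(8 // 2, dim=-1)
assert forward_half.size(-1) == backward_half.size(-1) == 4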
Example #7
    def test_correct_sequence_elements_are_embedded_with_a_masked_sequence(
            self):
        sequence_tensor = torch.randn([2, 5, 8])
        # Concatenate the start and end points together to form our representation
        # for both the forward and backward directions.
        extractor = BidirectionalEndpointSpanExtractor(
            input_dim=8, forward_combination="x,y", backward_combination="x,y")
        indices = torch.LongTensor([
            [[1, 3], [2, 4]],
            # This span has an end index at the
            # end of the padded sequence.
            [[0, 2], [0, 1]],
        ])
        sequence_mask = torch.LongTensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])

        span_representations = extractor(sequence_tensor,
                                         indices,
                                         sequence_mask=sequence_mask)

        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        (
            forward_start_embeddings,
            forward_end_embeddings,
            backward_start_embeddings,
            backward_end_embeddings,
        ) = span_representations.split(4, -1)

        forward_sequence_tensor, backward_sequence_tensor = sequence_tensor.split(
            4, -1)

        # Forward direction => subtract 1 from start indices to make them exclusive.
        correct_forward_start_indices = torch.LongTensor([[0, 1], [-1, -1]])
        # These indices should be -1, so they'll be replaced with a sentinel. Here,
        # we'll set them to a value other than -1 so we can index select the indices and
        # replace them later.
        correct_forward_start_indices[1, 0] = 1
        correct_forward_start_indices[1, 1] = 1

        # Forward direction => end indices are the same.
        correct_forward_end_indices = torch.LongTensor([[3, 4], [2, 1]])

        # Backward direction => start indices are exclusive, so add 1 to the end indices.
        correct_backward_start_indices = torch.LongTensor([[4, 5], [3, 2]])
        # These exclusive backward start indices are outside the tensor, so will be replaced
        # with the end sentinel. Here we replace them with ones so we can index select using
        # these indices without torch complaining.
        correct_backward_start_indices[0, 1] = 1

        # Backward direction => end indices are inclusive and equal to the forward start indices.
        correct_backward_end_indices = torch.LongTensor([[1, 2], [0, 0]])

        correct_forward_start_embeddings = batched_index_select(
            forward_sequence_tensor.contiguous(),
            correct_forward_start_indices)
        # These elements had sequence_tensor index 0, so their exclusive indices are the start sentinel.
        correct_forward_start_embeddings[1, 0] = extractor._start_sentinel.data
        correct_forward_start_embeddings[1, 1] = extractor._start_sentinel.data
        numpy.testing.assert_array_equal(
            forward_start_embeddings.data.numpy(),
            correct_forward_start_embeddings.data.numpy())

        correct_forward_end_embeddings = batched_index_select(
            forward_sequence_tensor.contiguous(), correct_forward_end_indices)
        numpy.testing.assert_array_equal(
            forward_end_embeddings.data.numpy(),
            correct_forward_end_embeddings.data.numpy())

        correct_backward_end_embeddings = batched_index_select(
            backward_sequence_tensor.contiguous(),
            correct_backward_end_indices)
        numpy.testing.assert_array_equal(
            backward_end_embeddings.data.numpy(),
            correct_backward_end_embeddings.data.numpy())

        correct_backward_start_embeddings = batched_index_select(
            backward_sequence_tensor.contiguous(),
            correct_backward_start_indices)
        # This element had sequence_tensor index == sequence_tensor.size(1),
        # so its exclusive index is the end sentinel.
        correct_backward_start_embeddings[0, 1] = extractor._end_sentinel.data
        # This element has sequence_tensor index == the masked length of the batch element,
        # so it should be the end_sentinel even though it isn't greater than sequence_tensor.size(1).
        correct_backward_start_embeddings[1, 0] = extractor._end_sentinel.data

        numpy.testing.assert_array_equal(
            backward_start_embeddings.data.numpy(),
            correct_backward_start_embeddings.data.numpy())
Example #8
    def test_correct_sequence_elements_are_embedded(self):
        sequence_tensor = Variable(torch.randn([2, 5, 8]))
        # Concatenate the start and end points together to form our representation
        # for both the forward and backward directions.
        extractor = BidirectionalEndpointSpanExtractor(input_dim=8,
                                                       forward_combination="x,y",
                                                       backward_combination="x,y")
        indices = Variable(torch.LongTensor([[[1, 3],
                                              [2, 4]],
                                             [[0, 2],
                                              [3, 4]]]))

        span_representations = extractor(sequence_tensor, indices)

        assert list(span_representations.size()) == [2, 2, 16]
        assert extractor.get_output_dim() == 16
        assert extractor.get_input_dim() == 8

        # We just concatenated the start and end embeddings together, so
        # we can check they match the original indices if we split them apart.
        (forward_start_embeddings, forward_end_embeddings,
         backward_start_embeddings, backward_end_embeddings) = span_representations.split(4, -1)

        forward_sequence_tensor, backward_sequence_tensor = sequence_tensor.split(4, -1)

        # Forward direction => subtract 1 from start indices to make them exclusive.
        correct_forward_start_indices = Variable(torch.LongTensor([[0, 1],
                                                                   [-1, 2]]))
        # This index should be -1, so it will be replaced with a sentinel. Here,
        # we'll set it to a value other than -1 so we can index select the indices and
        # replace it later.
        correct_forward_start_indices[1, 0] = 1

        # Forward direction => end indices are the same.
        correct_forward_end_indices = Variable(torch.LongTensor([[3, 4], [2, 4]]))

        # Backward direction => start indices are exclusive, so add 1 to the end indices.
        correct_backward_start_indices = Variable(torch.LongTensor([[4, 5], [3, 5]]))
        # These exclusive backward start indices are outside the tensor, so they will be replaced
        # with the end sentinel. Here we replace them with ones so we can index select using these
        # indices without torch complaining.
        correct_backward_start_indices[0, 1] = 1
        correct_backward_start_indices[1, 1] = 1
        # Backward direction => end indices are inclusive and equal to the forward start indices.
        correct_backward_end_indices = Variable(torch.LongTensor([[1, 2], [0, 3]]))

        correct_forward_start_embeddings = batched_index_select(forward_sequence_tensor.contiguous(),
                                                                correct_forward_start_indices)
        # This element had sequence_tensor index of 0, so its exclusive index is the start sentinel.
        correct_forward_start_embeddings[1, 0] = extractor._start_sentinel.data
        numpy.testing.assert_array_equal(forward_start_embeddings.data.numpy(),
                                         correct_forward_start_embeddings.data.numpy())

        correct_forward_end_embeddings = batched_index_select(forward_sequence_tensor.contiguous(),
                                                              correct_forward_end_indices)
        numpy.testing.assert_array_equal(forward_end_embeddings.data.numpy(),
                                         correct_forward_end_embeddings.data.numpy())

        correct_backward_end_embeddings = batched_index_select(backward_sequence_tensor.contiguous(),
                                                               correct_backward_end_indices)
        numpy.testing.assert_array_equal(backward_end_embeddings.data.numpy(),
                                         correct_backward_end_embeddings.data.numpy())

        correct_backward_start_embeddings = batched_index_select(backward_sequence_tensor.contiguous(),
                                                                 correct_backward_start_indices)
        # This element had sequence_tensor index == sequence_tensor.size(1),
        # so its exclusive index is the end sentinel.
        correct_backward_start_embeddings[0, 1] = extractor._end_sentinel.data
        correct_backward_start_embeddings[1, 1] = extractor._end_sentinel.data
        numpy.testing.assert_array_equal(backward_start_embeddings.data.numpy(),
                                         correct_backward_start_embeddings.data.numpy())
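This last test repeats the earlier test_correct_sequence_elements_are_embedded, but written against the older PyTorch/AllenNLP API where inputs had to be wrapped in torch.autograd.Variable. Since PyTorch 0.4 merged Variable and Tensor, the wrapper is only a backwards-compatibility shim, as the small check below illustrates (illustrative, not from the original file).

import torch
from torch.autograd import Variable

x = torch.randn(2, 3)
# On PyTorch >= 0.4, Variable(x) simply returns a Tensor; gradient tracking is
# controlled by requires_grad instead.
assert torch.equal(Variable(x), x)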
Example #9
            "lang": {
                "inputs": "lang"
            }
        },
        allow_unmatched_keys=True)

    encoder: Seq2SeqEncoder = PytorchSeq2SeqWrapper(
        nn.LSTM(word_embedder.get_output_dim(),
                encoder_output_dim,
                2,
                dropout=0.4,
                bidirectional=True,
                batch_first=True))

    # span_extractor: SpanExtractor = SelfAttentiveSpanExtractor(input_dim=encoder.get_output_dim())
    span_extractor: SpanExtractor = BidirectionalEndpointSpanExtractor(
        input_dim=encoder.get_output_dim())
    # Probably the best solution is to make this a factory: add a get_decoder function that takes
    # the vocab and the needed dimension as input.
    span_decoder = Topdown_Span_Parser_Factory()
    remote_parser = Basic_Remote_Parser_Factory(remote_parser_mlp_dim)

    model = UccaSpanParser(word_embedder, encoder, span_decoder,
                           span_extractor, remote_parser, UccaScores(), vocab)

    if torch.cuda.is_available():
        cuda_device = list(range(torch.cuda.device_count()))
        model = model.cuda(cuda_device[0])
    else:
        cuda_device = -1

    trainer = Trainer(model=model,