コード例 #1
0
    def _pad_and_sort_batch(self, DataLoaderBatch):
        """Collate raw NER samples into a sorted ``BatchRepresentation``.

        Every sample in ``DataLoaderBatch`` is unpacked as a 9-tuple:
        (subword sequence, target, tokens, position changes, character
        sequence, token character counts, feature set, document id,
        segment id). Targets are dropped (replaced by ``None``) when running
        the test split without evaluation enabled. Document and segment ids
        are paired up and attached as additional information.
        """
        batch_size = len(DataLoaderBatch)
        (sequences, targets, tokens, position_changes, character_sequences,
         token_characters_count, feature_set, document_ids,
         segment_ids) = zip(*DataLoaderBatch)

        # Without evaluation, a test run has no gold labels to forward.
        drop_targets = (not self._arguments_service.evaluate
                        and self._run_type == RunType.Test)
        if drop_targets:
            targets = None

        pad_idx = self._ner_process_service.pad_idx
        batch = BatchRepresentation(
            device=self._device,
            batch_size=batch_size,
            subword_sequences=sequences,
            character_sequences=character_sequences,
            subword_characters_count=token_characters_count,
            targets=targets,
            tokens=tokens,
            position_changes=position_changes,
            manual_features=feature_set,
            additional_information=list(zip(document_ids, segment_ids)),
            pad_idx=pad_idx)

        batch.sort_batch()
        return batch
コード例 #2
0
    def collate_function(self, batch_input):
        """Build a sorted ``BatchRepresentation`` from (sequence, target) pairs.

        The pairs in ``batch_input`` are transposed into one tuple of
        subword sequences and one tuple of targets before being wrapped.
        """
        batch_size = len(batch_input)
        sequences, targets = zip(*batch_input)

        batch = BatchRepresentation(
            device=self._device,
            batch_size=batch_size,
            subword_sequences=sequences,
            targets=targets)
        batch.sort_batch()
        return batch
コード例 #3
0
    def _pad_and_sort_batch(self, DataLoaderBatch):
        """Collate CBOW samples into a sorted ``BatchRepresentation``.

        Each sample is a ``(context_word_ids, target)`` pair. The context ids
        become the word sequences, the targets are forwarded as a list, and
        padding uses the pad index of the CBOW process service.
        """
        batch_size = len(DataLoaderBatch)
        context_word_ids, targets = zip(*DataLoaderBatch)

        batch = BatchRepresentation(
            device=self._device,
            batch_size=batch_size,
            word_sequences=context_word_ids,
            targets=list(targets),
            pad_idx=self._cbow_process_service._pad_idx)
        batch.sort_batch()
        return batch
コード例 #4
0
    def forward(self, input_batch: BatchRepresentation, debug=False, **kwargs):
        """Run the character LSTM and align the output with the targets.

        The embedded batch is packed using ``input_batch.character_lengths``,
        passed through ``self.lstm``, unpacked, and fed to the output layer.
        ``pad_packed_sequence`` only pads up to the longest sequence in the
        batch, so the output's sequence dimension (dim 1) can differ from the
        targets'. Whichever of the two is shorter is zero-padded so that both
        have the same length.

        Args:
            input_batch: batch to process; ``input_batch.targets`` must be set.
            debug: unused; kept for interface compatibility.

        Returns:
            ``(output, targets)`` with matching sequence lengths. When the
            targets had to be padded, the padded tensor is written back onto
            ``input_batch`` and returned.
        """
        embedded = self._embedding_layer.forward(input_batch)

        lengths = input_batch.character_lengths
        if torch.is_tensor(lengths):
            # Newer PyTorch requires pack_padded_sequence lengths on the CPU;
            # this is a no-op when they already are.
            lengths = lengths.cpu()

        x_packed = pack_padded_sequence(embedded, lengths, batch_first=True)
        packed_output, _ = self.lstm.forward(x_packed)
        rnn_output, _ = pad_packed_sequence(packed_output, batch_first=True)

        output = self._output_layer.forward(rnn_output)

        targets = input_batch.targets
        output_length = output.shape[1]
        target_length = targets.shape[1]
        device = self._arguments_service.device

        if output_length < target_length:
            # Zero-pad the predictions. Allocate directly on the device and
            # keep the output's dtype instead of defaulting to float32.
            padded_output = torch.zeros(
                (output.shape[0], target_length, output.shape[2]),
                dtype=output.dtype,
                device=device)
            padded_output[:, :output_length, :] = output
            output = padded_output
        elif output_length > target_length:
            # Zero-pad the targets (int64, 0 as fill value) up to the output length.
            padded_targets = torch.zeros(
                (targets.shape[0], output_length),
                dtype=torch.int64,
                device=device)
            padded_targets[:, :target_length] = targets
            # Written to the private attribute that presumably backs the
            # ``targets`` property (read back on return) — TODO confirm.
            input_batch._targets = padded_targets

        return output, input_batch.targets
コード例 #5
0
    def _pad_and_sort_batch(self, DataLoaderBatch):
        """Collate OCR samples into a sorted ``BatchRepresentation``.

        Each sample is unpacked as a 4-tuple whose second element is ignored.
        The OCR text is used as the character sequence and the ground-truth
        text as the target. Offset lists are not provided yet.
        """
        batch_size = len(DataLoaderBatch)
        sequences, _ignored, ocr_texts, gs_texts = zip(*DataLoaderBatch)

        batch = BatchRepresentation(
            device=self._device,
            batch_size=batch_size,
            subword_sequences=sequences,
            character_sequences=ocr_texts,
            targets=gs_texts,
            offset_lists=None)  # TODO: add offset lists
        batch.sort_batch()
        return batch
コード例 #6
0
    def collate_function(self, batch_input):
        """Collate OCR samples, including tokens and offsets, into a sorted batch.

        Each sample is unpacked as a 5-tuple:
        (subword sequence, OCR text, ground-truth text, offset list, tokens).
        The OCR text becomes the character sequence and the ground-truth
        text becomes the target.
        """
        batch_size = len(batch_input)
        (sequences, ocr_texts, gs_texts,
         offset_lists, tokens) = zip(*batch_input)

        batch = BatchRepresentation(
            device=self._device,
            batch_size=batch_size,
            subword_sequences=sequences,
            character_sequences=ocr_texts,
            targets=gs_texts,
            tokens=tokens,
            offset_lists=offset_lists)
        batch.sort_batch()
        return batch
コード例 #7
0
    def forward(self, input_batch: BatchRepresentation, debug=False, **kwargs):
        """Run the encoder-decoder on a batch using teacher forcing.

        The character sequences are encoded and a source mask is built from
        them. When the module is not in training mode, predictions are also
        decoded up to the target length via ``self._decode_predictions``.
        The decoder is then run on the targets without their last position,
        and its outputs are passed through the generator.

        Args:
            input_batch: batch carrying ``character_sequences`` and ``targets``.
            debug: unused; kept for interface compatibility.

        Returns:
            ``(gen_outputs, targets, predictions)``. ``predictions`` is
            ``None`` while training.
        """
        # Per the original design, the encoder's final state seeds the decoder.
        encoder_hidden, encoder_final = self._encoder.forward(input_batch)

        # The mask depends only on the input characters. Build it once and
        # reuse it for inference decoding and the teacher-forced pass,
        # instead of generating it twice.
        src_mask = input_batch.generate_mask(input_batch.character_sequences)
        targets = input_batch.targets

        predictions = None
        if not self.training:
            predictions = self._decode_predictions(
                src_mask, encoder_hidden, encoder_final, targets.shape[1])

        # Drop the final target position before feeding the decoder (the
        # original comment describes this as removing the [EOS] tokens).
        outputs, _ = self._decoder.forward(
            targets[:, :-1],
            encoder_hidden,
            encoder_final,
            src_mask)

        gen_outputs = self._generator.forward(outputs)

        return gen_outputs, targets, predictions
コード例 #8
0
    def collate_function(self, DataLoaderBatch):
        """Wrap a batch of ``(word_id, word)`` samples in a ``BatchRepresentation``.

        The batch size is hard-coded to 1, and only the first word string is
        attached as additional information. This looks like it assumes the
        DataLoader yields a single sample per batch — TODO confirm with the
        caller.
        """
        word_ids, words = zip(*DataLoaderBatch)

        return BatchRepresentation(
            device=self._arguments_service.device,
            batch_size=1,
            word_sequences=[word_ids],
            additional_information=words[0])
コード例 #9
0
    def _calculate_word_embeddings(self, word: str) -> List[np.ndarray]:
        """Return the model outputs for a single word as NumPy arrays.

        The word is mapped to its vocabulary id and wrapped in a one-element
        batch, which is run through ``self._model``. Each returned output is
        averaged over dimension 1, detached, moved to the CPU, and converted
        to a NumPy array.
        """
        import torch  # function-scope import keeps this block self-contained

        word_id = self._vocabulary_service.string_to_id(word)

        batch = BatchRepresentation(device=self._arguments_service.device,
                                    batch_size=1,
                                    word_sequences=[[word_id]])

        # Only the values are needed, so skip building an autograd graph.
        with torch.no_grad():
            outputs = self._model.forward(batch)
            return [
                output.mean(dim=1).detach().cpu().numpy() for output in outputs
            ]
コード例 #10
0
    def _greedy_decode(self,
                       encoder_hidden,
                       encoder_final,
                       targets,
                       src_mask,
                       hidden=None):
        """Unroll the decoder for ``targets.shape[1]`` steps.

        NOTE(review): despite the name, each step is fed the embedding of the
        corresponding *target* position (teacher forcing), not the previous
        step's prediction — confirm this is intended.

        Args:
            encoder_hidden: encoder states. Projected once by the attention
                key layer and passed to every decoder step.
            encoder_final: used to initialise ``hidden`` when it is ``None``.
            targets: tensor unpacked as (batch_size, steps); its embeddings
                drive each step.
            src_mask: source mask forwarded to every decoder step.
            hidden: optional initial decoder state.

        Returns:
            ``(decoder_states, hidden, pre_output_vectors)``. The first and
            last are the per-step outputs concatenated along dim 1.
        """
        batch_size, step_count = targets.shape

        if hidden is None:
            hidden = self.init_hidden(encoder_final)

        target_batch = BatchRepresentation(device=self._device,
                                           batch_size=batch_size,
                                           character_sequences=targets)
        embedded_targets = self._embedding_layer.forward(
            target_batch, skip_pretrained_representation=True)

        # Project the encoder states into attention keys once, up front,
        # rather than on every step.
        projected_keys = self._attention_layer.key_layer(encoder_hidden)

        step_outputs = []
        step_pre_outputs = []
        for step in range(step_count):
            step_input = embedded_targets[:, step].unsqueeze(1)
            step_output, hidden, step_pre_output = self._internal_forward(
                step_input, encoder_hidden, src_mask, projected_keys, hidden)
            step_outputs.append(step_output)
            step_pre_outputs.append(step_pre_output)

        all_states = torch.cat(step_outputs, dim=1)
        all_pre_outputs = torch.cat(step_pre_outputs, dim=1)
        return all_states, hidden, all_pre_outputs  # [B, N, D]