def _pad_and_sort_batch(self, DataLoaderBatch):
    batch_size = len(DataLoaderBatch)
    batch_split = list(zip(*DataLoaderBatch))
    (sequences, targets, tokens, position_changes, character_sequences,
     token_characters_count, feature_set, document_ids, segment_ids) = batch_split

    # Targets are unavailable when running plain inference on the test set
    if not self._arguments_service.evaluate and self._run_type == RunType.Test:
        targets = None

    pad_idx = self._ner_process_service.pad_idx
    batch_representation = BatchRepresentation(
        device=self._device,
        batch_size=batch_size,
        subword_sequences=sequences,
        character_sequences=character_sequences,
        subword_characters_count=token_characters_count,
        targets=targets,
        tokens=tokens,
        position_changes=position_changes,
        manual_features=feature_set,
        additional_information=list(zip(document_ids, segment_ids)),
        pad_idx=pad_idx)

    batch_representation.sort_batch()
    return batch_representation

def collate_function(self, batch_input):
    batch_size = len(batch_input)
    batch_split = list(zip(*batch_input))
    sequences, targets = batch_split

    batch_representation = BatchRepresentation(
        device=self._device,
        batch_size=batch_size,
        subword_sequences=sequences,
        targets=targets)

    batch_representation.sort_batch()
    return batch_representation

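# Usage sketch (not from this repo): collate functions like the ones above are
# designed to be handed to torch.utils.data.DataLoader via its `collate_fn`
# argument, so that each raw sample list arrives at the model as a padded,
# sorted BatchRepresentation. The `dataset` and `service` arguments here are
# hypothetical stand-ins.
from torch.utils.data import DataLoader

def make_data_loader(dataset, service, batch_size: int = 32) -> DataLoader:
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=service.collate_function)
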
def _pad_and_sort_batch(self, DataLoaderBatch):
    batch_size = len(DataLoaderBatch)
    batch_split = list(zip(*DataLoaderBatch))
    context_word_ids, targets = batch_split

    batch_representation = BatchRepresentation(
        device=self._device,
        batch_size=batch_size,
        word_sequences=context_word_ids,
        targets=list(targets),
        pad_idx=self._cbow_process_service._pad_idx)

    batch_representation.sort_batch()
    return batch_representation

def forward(self, input_batch: BatchRepresentation, debug=False, **kwargs):
    embedded = self._embedding_layer.forward(input_batch)

    # Pack so the LSTM skips padded timesteps, then unpack back to [B, T, H]
    x_packed = pack_padded_sequence(
        embedded, input_batch.character_lengths, batch_first=True)
    packed_output, _ = self.lstm.forward(x_packed)
    rnn_output, _ = pad_packed_sequence(packed_output, batch_first=True)

    output = self._output_layer.forward(rnn_output)

    # After unpacking, the time dimension equals the longest sequence in this
    # batch, which may differ from the padded target length; zero-pad the
    # shorter of the two so that output and targets line up.
    if output.shape[1] < input_batch.targets.shape[1]:
        padded_output = torch.zeros(
            (output.shape[0], input_batch.targets.shape[1], output.shape[2]),
            device=self._arguments_service.device)
        padded_output[:, :output.shape[1], :] = output
        output = padded_output
    elif output.shape[1] > input_batch.targets.shape[1]:
        padded_targets = torch.zeros(
            (input_batch.targets.shape[0], output.shape[1]),
            dtype=torch.int64,
            device=self._arguments_service.device)
        padded_targets[:, :input_batch.targets.shape[1]] = input_batch.targets
        input_batch._targets = padded_targets

    return output, input_batch.targets

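# Standalone sketch of the pack -> LSTM -> unpack pattern used above, in plain
# PyTorch with made-up dimensions. Packing lets the LSTM skip padded
# timesteps; unpacking restores a [B, T, H] tensor whose T equals the longest
# sequence in *this* batch, which is exactly why the forward pass above has to
# re-pad either the output or the targets before returning.
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

lstm = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
embedded = torch.randn(3, 5, 8)    # batch of 3 sequences, padded to length 5
lengths = torch.tensor([5, 3, 2])  # true lengths, sorted in decreasing order

packed = pack_padded_sequence(embedded, lengths, batch_first=True)
packed_output, _ = lstm(packed)
output, _ = pad_packed_sequence(packed_output, batch_first=True)
print(output.shape)                # torch.Size([3, 5, 16])
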
def _pad_and_sort_batch(self, DataLoaderBatch):
    batch_size = len(DataLoaderBatch)
    batch_split = list(zip(*DataLoaderBatch))
    sequences, _, ocr_texts, gs_texts = batch_split

    batch_representation = BatchRepresentation(
        device=self._device,
        batch_size=batch_size,
        subword_sequences=sequences,
        character_sequences=ocr_texts,
        targets=gs_texts,
        offset_lists=None)  # TODO: add offset lists

    batch_representation.sort_batch()
    return batch_representation

def collate_function(self, batch_input):
    batch_size = len(batch_input)
    batch_split = list(zip(*batch_input))
    sequences, ocr_texts, gs_texts, offset_lists, tokens = batch_split

    batch_representation = BatchRepresentation(
        device=self._device,
        batch_size=batch_size,
        subword_sequences=sequences,
        character_sequences=ocr_texts,
        targets=gs_texts,
        tokens=tokens,
        offset_lists=offset_lists)

    batch_representation.sort_batch()
    return batch_representation

def forward(self, input_batch: BatchRepresentation, debug=False, **kwargs):
    # The last hidden state of the encoder is used as the initial hidden
    # state of the decoder
    encoder_hidden, encoder_final = self._encoder.forward(input_batch)
    src_mask = input_batch.generate_mask(input_batch.character_sequences)

    # During evaluation, also decode step by step instead of only scoring
    # the teacher-forced targets
    predictions = None
    if not self.training:
        predictions = self._decode_predictions(
            src_mask,
            encoder_hidden,
            encoder_final,
            input_batch.targets.shape[1])

    outputs, _ = self._decoder.forward(
        input_batch.targets[:, :-1],  # drop the final ([EOS]) position
        encoder_hidden,
        encoder_final,
        src_mask)  # reuse the mask computed above instead of regenerating it

    gen_outputs = self._generator.forward(outputs)

    return gen_outputs, input_batch.targets, predictions

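# Sketch of what `generate_mask` plausibly computes; the actual
# BatchRepresentation implementation is not shown here, so this is an
# assumption. Attention padding masks are commonly just boolean tensors that
# are True at real tokens and False at padding:
import torch

def generate_padding_mask(sequences: torch.Tensor, pad_idx: int) -> torch.Tensor:
    # [B, T] -> [B, 1, T] so the mask broadcasts over attention scores
    return (sequences != pad_idx).unsqueeze(1)
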
def collate_function(self, DataLoaderBatch):
    batch_split = list(zip(*DataLoaderBatch))
    word_id, word = batch_split

    # Single-word batches: the raw word string is carried along as
    # additional information next to its vocabulary id
    batch_representation = BatchRepresentation(
        device=self._arguments_service.device,
        batch_size=1,
        word_sequences=[word_id],
        additional_information=word[0])

    return batch_representation

def _calculate_word_embeddings(self, word: str) -> List[np.ndarray]:
    # word_tokens, _, _, _ = self._tokenize_service.encode_sequence(word)
    word_id = self._vocabulary_service.string_to_id(word)

    batch = BatchRepresentation(
        device=self._arguments_service.device,
        batch_size=1,
        word_sequences=[[word_id]])

    outputs = self._model.forward(batch)

    # Average over the sequence dimension to get one vector per model output
    return [
        output.mean(dim=1).detach().cpu().numpy()
        for output in outputs
    ]

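# Hedged follow-up: given the per-output vectors returned by
# _calculate_word_embeddings, a standard way to compare two embeddings of the
# same word (e.g. from two different model outputs) is cosine distance. This
# helper is illustrative and not part of the repo.
import numpy as np

def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    a, b = a.ravel(), b.ravel()
    return 1.0 - float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
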
def _greedy_decode(self, encoder_hidden, encoder_final, targets, src_mask, hidden=None):
    # the maximum number of steps to unroll the RNN
    batch_size, max_len = targets.shape

    # initialize decoder hidden state
    if hidden is None:
        hidden = self.init_hidden(encoder_final)

    input_batch = BatchRepresentation(
        device=self._device,
        batch_size=batch_size,
        character_sequences=targets)

    target_embeddings = self._embedding_layer.forward(
        input_batch, skip_pretrained_representation=True)

    # pre-compute projected encoder hidden states
    # (the "keys" for the attention mechanism)
    # this is only done for efficiency
    proj_key = self._attention_layer.key_layer(encoder_hidden)

    # here we store all intermediate hidden states and pre-output vectors
    decoder_states = []
    pre_output_vectors = []

    # unroll the decoder RNN for max_len steps; note that the ground-truth
    # target embedding is fed in at every step (teacher forcing)
    for i in range(max_len):
        prev_embed = target_embeddings[:, i].unsqueeze(1)
        output, hidden, pre_output = self._internal_forward(
            prev_embed, encoder_hidden, src_mask, proj_key, hidden)
        decoder_states.append(output)
        pre_output_vectors.append(pre_output)

    decoder_states = torch.cat(decoder_states, dim=1)
    pre_output_vectors = torch.cat(pre_output_vectors, dim=1)

    return decoder_states, hidden, pre_output_vectors  # [B, N, D]