def _pad_batch_records(self, batch_records):
        """Pad one batch of sequence-labeling records into model inputs.

        Each record must expose ``token_ids``, ``text_type_ids``,
        ``position_ids`` and ``label_ids``. Token, text-type and position
        sequences are padded with ``self.pad_id``; label sequences are padded
        with the last label index, ``len(self.label_map) - 1``.

        Returns:
            A list, in order: padded token ids, padded text-type ids, padded
            position ids, task ids (same shape as the padded token ids, every
            entry ``self.task_id``), the input mask, padded label ids, and the
            sequence lengths reported by ``pad_batch_data``.
        """
        tokens = [rec.token_ids for rec in batch_records]
        text_types = [rec.text_type_ids for rec in batch_records]
        positions = [rec.position_ids for rec in batch_records]
        labels = [rec.label_ids for rec in batch_records]

        # Only the token padding also produces the input mask and lengths.
        tokens_pad, mask, seq_lens = pad_batch_data(
            tokens,
            pad_idx=self.pad_id,
            return_input_mask=True,
            return_seq_lens=True)
        text_types_pad = pad_batch_data(text_types, pad_idx=self.pad_id)
        positions_pad = pad_batch_data(positions, pad_idx=self.pad_id)

        # Label padding uses the last index of label_map, not pad_id.
        label_pad_idx = len(self.label_map) - 1
        labels_pad = pad_batch_data(labels, pad_idx=label_pad_idx)

        task_ids = np.ones_like(tokens_pad, dtype="int64") * self.task_id

        return [
            tokens_pad, text_types_pad, positions_pad, task_ids, mask,
            labels_pad, seq_lens
        ]
Пример #2
0
    def _pad_batch_records(self, batch_records, is_training):
        """Pad one batch of reading-comprehension records into model inputs.

        Args:
            batch_records: records exposing ``token_ids``, ``text_type_ids``,
                ``position_ids`` and ``unique_id``; when ``is_training`` is
                true they must also expose ``start_position`` and
                ``end_position``.
            is_training: if true, start/end positions are read from the
                records; otherwise all-zero int64 arrays with one entry per
                record are used instead.

        Returns:
            A list, in order: padded token ids, padded text-type ids, padded
            position ids, task ids (every entry ``self.task_id``), the input
            mask, start positions, end positions and unique ids. The last
            three are 1-D int64 arrays.
        """
        tokens = [rec.token_ids for rec in batch_records]
        text_types = [rec.text_type_ids for rec in batch_records]
        positions = [rec.position_ids for rec in batch_records]

        def _as_flat_int64(values):
            # Flatten the collected values into a 1-D int64 array.
            return np.array(values).astype("int64").reshape([-1])

        if is_training:
            raw_starts = [rec.start_position for rec in batch_records]
            raw_ends = [rec.end_position for rec in batch_records]
            starts = _as_flat_int64(raw_starts)
            ends = _as_flat_int64(raw_ends)
        else:
            # Positions are not read outside training; zero placeholders keep
            # the returned list the same length in both modes.
            n_records = len(tokens)
            starts = np.zeros(shape=[n_records], dtype="int64")
            ends = np.zeros(shape=[n_records], dtype="int64")

        unique_ids = _as_flat_int64(
            [rec.unique_id for rec in batch_records])

        tokens_pad, mask = pad_batch_data(
            tokens, pad_idx=self.pad_id, return_input_mask=True)
        text_types_pad = pad_batch_data(text_types, pad_idx=self.pad_id)
        positions_pad = pad_batch_data(positions, pad_idx=self.pad_id)
        task_ids = np.ones_like(tokens_pad, dtype="int64") * self.task_id

        return [
            tokens_pad, text_types_pad, positions_pad, task_ids, mask,
            starts, ends, unique_ids
        ]
    def _pad_batch_records(self, batch_records):
        """Pad one batch of classification/ranking records into model inputs.

        Every record exposes ``token_ids``, ``text_type_ids`` and
        ``position_ids``. What follows the five base inputs depends on
        ``self.phase`` and ``self.learning_strategy``:

        * ``phase == 'train'`` with ``'pairwise'``: the same five inputs built
          from each record's negative example (``token_ids_neg``,
          ``text_type_ids_neg``, ``position_ids_neg``).
        * ``phase == 'train'`` with ``'pointwise'``: the labels
          (``label_id``), as a 1-D int64 array when ``self.is_classify``, a
          1-D float32 array when ``self.is_regression``, otherwise the plain
          list.
        * any other combination: nothing.

        Returns:
            A list starting with padded token ids, padded text-type ids,
            padded position ids, task ids (every entry ``self.task_id``) and
            the input mask, followed by the phase/strategy-dependent extras
            described above.
        """
        def _pad_inputs(token_ids, text_type_ids, position_ids):
            # Pad one view (positive or negative) of the batch into the five
            # base model inputs, in feed order.
            padded_token_ids, input_mask = pad_batch_data(
                token_ids, pad_idx=self.pad_id, return_input_mask=True)
            padded_text_type_ids = pad_batch_data(
                text_type_ids, pad_idx=self.pad_id)
            padded_position_ids = pad_batch_data(
                position_ids, pad_idx=self.pad_id)
            padded_task_ids = np.ones_like(
                padded_token_ids, dtype="int64") * self.task_id
            return [
                padded_token_ids, padded_text_type_ids, padded_position_ids,
                padded_task_ids, input_mask
            ]

        return_list = _pad_inputs(
            [record.token_ids for record in batch_records],
            [record.text_type_ids for record in batch_records],
            [record.position_ids for record in batch_records])

        # Extras are only emitted in the training phase, so nothing else
        # needs to be gathered from the records otherwise.
        if self.phase != 'train':
            return return_list

        if self.learning_strategy == 'pairwise':
            return_list += _pad_inputs(
                [record.token_ids_neg for record in batch_records],
                [record.text_type_ids_neg for record in batch_records],
                [record.position_ids_neg for record in batch_records])
        elif self.learning_strategy == 'pointwise':
            # Labels are built here, where they are actually returned.
            batch_labels = [record.label_id for record in batch_records]
            if self.is_classify:
                batch_labels = np.array(batch_labels).astype(
                    "int64").reshape([-1])
            elif self.is_regression:
                batch_labels = np.array(batch_labels).astype(
                    "float32").reshape([-1])
            return_list.append(batch_labels)

        return return_list