Code Example #1
File: rat_text2sql.py  Project: kev2513/gap-text2sql
    def collate_batch(self, examples) -> Dict[str, torch.Tensor]:
        # Gather the per-example fields into parallel lists.
        input_ids_sequences = [example["input_ids"] for example in examples]
        column_spans_sequences = [example["column_spans"] for example in examples]
        label_ids_sequences = [example["label_ids"] for example in examples]

        # Pad every sequence in the batch to a common length and stack
        # into tensors. Column spans are (start, end) pairs, so the pad
        # value is the dummy span (0, 1).
        padded_input_ids_tensor = pad_and_tensorize_sequence(
            input_ids_sequences, padding_value=self.tokenizer.pad_token_id)
        padded_column_spans_tensor = pad_and_tensorize_sequence(
            column_spans_sequences, padding_value=(0, 1))
        label_ids_tensor = pad_and_tensorize_sequence(
            label_ids_sequences, padding_value=self.label_padding_id)

        # Per-example metadata is passed through unpadded.
        example_info_list = [example["example_info"] for example in examples]

        return {
            "input_ids": padded_input_ids_tensor,
            "column_spans": padded_column_spans_tensor,
            "labels": label_ids_tensor,
            "example_info_list": example_info_list,
            "input_padding_id": self.tokenizer.pad_token_id,
            "label_padding_id": self.label_padding_id,
            "label_eos_id": self.label_eos_id,
            "label_bos_id": self.label_bos_id
        }
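All three examples on this page call pad_and_tensorize_sequence, but only its call sites appear in the snippets, not its body. A minimal sketch of what such a helper could look like, assuming right-padding to the batch maximum length (the project's actual implementation may differ):

import torch

def pad_and_tensorize_sequence(sequences, padding_value):
    # Hypothetical reconstruction: pad every sequence on the right to
    # the longest length in the batch, then stack into one tensor.
    # padding_value may be a scalar (token ids, giving a
    # (batch, max_len) tensor) or a tuple such as (0, 1) (column spans,
    # giving a (batch, max_len, 2) tensor).
    max_len = max(len(seq) for seq in sequences)
    padded = [
        list(seq) + [padding_value] * (max_len - len(seq))
        for seq in sequences
    ]
    return torch.tensor(padded)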
Code Example #2
File: sql_to_text.py  Project: SSamDav/gap-text2sql
    def collate_batch(self, examples):
        # SQL-to-text direction: SQL token ids become the model inputs
        # and the natural-language token ids become the labels.
        text_ids_sequences = [
            example["text_token_ids"] for example in examples
        ]
        sql_ids_sequences = [example["sql_token_ids"] for example in examples]

        padded_text_ids_tensor = pad_and_tensorize_sequence(
            text_ids_sequences, padding_value=self.tokenizer.pad_token_id)
        padded_sql_ids_tensor = pad_and_tensorize_sequence(
            sql_ids_sequences, padding_value=self.tokenizer.pad_token_id)

        return {
            "input_ids": padded_sql_ids_tensor,
            "labels": padded_text_ids_tensor,
            "pad_token_id": self.tokenizer.pad_token_id,
            "label_eos_id": self.label_eos_id,
            "label_bos_id": self.label_bos_id,
            "label_padding_id": self.tokenizer.pad_token_id
        }
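Collators like these plug into a PyTorch DataLoader via collate_fn. An illustrative wiring sketch, where "collator" and "dataset" are placeholder names for the surrounding objects, which the snippet does not show:

from torch.utils.data import DataLoader

# Illustrative only: "collator" is an instance of the class containing
# collate_batch above, "dataset" yields the example dicts it expects.
loader = DataLoader(dataset, batch_size=16, shuffle=True,
                    collate_fn=collator.collate_batch)
batch = next(iter(loader))
# batch["input_ids"]: (16, max_sql_len) padded SQL token ids
# batch["labels"]:    (16, max_text_len) padded text token ids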
Code Example #3
File: tabart.py  Project: kev2513/gap-text2sql
    def collate_batch(self, examples):
        input_ids_sequences = [example["input_ids"] for example in examples]
        padded_input_ids_tensor = pad_and_tensorize_sequence(
            input_ids_sequences, padding_value=self.tokenizer.pad_token_id)
        if self.task == "mlm":
            # Masked language modeling: corrupt a copy of the inputs and
            # use the original (unmasked) ids as labels. The per-token
            # labels returned by mask_tokens are discarded here.
            inputs, _ = self.mask_tokens(padded_input_ids_tensor.clone())
            return {
                "task": "mlm",
                "input_ids": inputs,
                "labels": padded_input_ids_tensor,
                "pad_token_id": self.tokenizer.pad_token_id,
                "label_bos_id": self.tokenizer.bos_token_id,
                "label_eos_id": self.tokenizer.eos_token_id,
                "label_padding_id": self.tokenizer.pad_token_id
            }
        elif self.task == "col_pred":
            # Column prediction: pad the per-column labels with -100 so
            # padded positions are ignored by the loss, and pad the
            # (start, end) column spans with the dummy span (0, 1).
            column_labels_sequences = [
                example["column_labels"] for example in examples
            ]
            padded_label_ids_tensor = pad_and_tensorize_sequence(
                column_labels_sequences, padding_value=-100)
            column_spans_sequences = [
                example["column_spans"] for example in examples
            ]
            padded_column_spans_tensor = pad_and_tensorize_sequence(
                column_spans_sequences, padding_value=(0, 1))
            return {
                "task": "col_pred",
                "input_ids": padded_input_ids_tensor,
                "column_spans": padded_column_spans_tensor,
                "labels": padded_label_ids_tensor,
                "pad_token_id": self.tokenizer.pad_token_id
            }
        elif self.task == "mlm+col_pred":
            # Mixed pretraining: sample the MLM objective for 60% of
            # batches and column prediction otherwise; each branch
            # mirrors the corresponding single-task branch above.
            if random.random() < 0.6:
                inputs, _ = self.mask_tokens(
                    padded_input_ids_tensor.clone())
                return {
                    "task": "mlm",
                    "input_ids": inputs,
                    "labels": padded_input_ids_tensor,
                    "pad_token_id": self.tokenizer.pad_token_id,
                    "label_bos_id": self.tokenizer.bos_token_id,
                    "label_eos_id": self.tokenizer.eos_token_id,
                    "label_padding_id": self.tokenizer.pad_token_id
                }
            else:
                column_labels_sequences = [
                    example["column_labels"] for example in examples
                ]
                padded_label_ids_tensor = pad_and_tensorize_sequence(
                    column_labels_sequences, padding_value=-100)
                column_spans_sequences = [
                    example["column_spans"] for example in examples
                ]
                padded_column_spans_tensor = pad_and_tensorize_sequence(
                    column_spans_sequences, padding_value=(0, 1))
                return {
                    "task": "col_pred",
                    "input_ids": padded_input_ids_tensor,
                    "column_spans": padded_column_spans_tensor,
                    "labels": padded_label_ids_tensor,
                    "pad_token_id": self.tokenizer.pad_token_id
                }
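The mlm branches above rely on self.mask_tokens, whose body is also not included in the snippet. A hedged sketch following the standard BERT-style 80/10/10 masking recipe (the actual gap-text2sql implementation may differ in detail):

import torch

def mask_tokens(self, inputs, mlm_probability=0.15):
    # Hypothetical sketch: pick ~15% of non-special tokens, replace 80%
    # of them with the mask token, 10% with random tokens, and leave 10%
    # unchanged; unpicked positions get label -100 so they are ignored
    # by the loss.
    labels = inputs.clone()
    probability_matrix = torch.full(labels.shape, mlm_probability)
    special_tokens_mask = torch.tensor([
        self.tokenizer.get_special_tokens_mask(
            seq.tolist(), already_has_special_tokens=True)
        for seq in labels
    ], dtype=torch.bool)
    probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
    masked_indices = torch.bernoulli(probability_matrix).bool()
    labels[~masked_indices] = -100

    # 80% of masked positions -> mask token.
    indices_replaced = (torch.bernoulli(torch.full(labels.shape, 0.8)).bool()
                        & masked_indices)
    inputs[indices_replaced] = self.tokenizer.mask_token_id

    # 10% of masked positions -> random token (half of the remainder).
    indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
                      & masked_indices & ~indices_replaced)
    random_words = torch.randint(
        len(self.tokenizer), labels.shape, dtype=torch.long)
    inputs[indices_random] = random_words[indices_random]

    # The remaining 10% of masked positions stay unchanged.
    return inputs, labels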