Example #1
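The full output_spec of a BERT-style GLUE model: per-segment token fields, either a regression score or multiclass probabilities, CLS and per-layer average embeddings, token-level input embeddings and gradients for Integrated Gradients, and one attention field per layer.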
  def output_spec(self) -> Spec:
    ret = {"tokens": lit_types.Tokens()}
    ret["tokens_" + self.config.text_a_name] = lit_types.Tokens(
        parent=self.config.text_a_name)
    if self.config.text_b_name:
      ret["tokens_" + self.config.text_b_name] = lit_types.Tokens(
          parent=self.config.text_b_name)
    if self.is_regression:
      ret["score"] = lit_types.RegressionScore(parent=self.config.label_name)
    else:
      ret["probas"] = lit_types.MulticlassPreds(
          parent=self.config.label_name,
          vocab=self.config.labels,
          null_idx=self.config.null_label_idx)
    ret["cls_emb"] = lit_types.Embeddings()
    # Average embeddings, one field per layer; layer 0 is the input embedding layer.
    for i in range(1 + self.model.config.num_hidden_layers):
      ret[f"layer_{i}/avg_emb"] = lit_types.Embeddings()

    ret["cls_grad"] = lit_types.Gradients(
        grad_for="cls_emb", grad_target_field_key="grad_class")

    # The input_embs_ and grad_class fields are used for Integrated Gradients.
    ret["input_embs_" + self.config.text_a_name] = lit_types.TokenEmbeddings(
        align="tokens_" + self.config.text_a_name)
    if self.config.text_b_name:
      ret["input_embs_" + self.config.text_b_name] = lit_types.TokenEmbeddings(
          align="tokens_" + self.config.text_b_name)

    # Gradients, if requested.
    if self.config.compute_grads:
      ret["grad_class"] = lit_types.CategoryLabel(required=False,
                                                  vocab=self.config.labels)
      ret["token_grad_" + self.config.text_a_name] = lit_types.TokenGradients(
          align="tokens_" + self.config.text_a_name,
          grad_for="input_embs_" + self.config.text_a_name,
          grad_target_field_key="grad_class")
      if self.config.text_b_name:
        ret["token_grad_" + self.config.text_b_name] = lit_types.TokenGradients(
            align="tokens_" + self.config.text_b_name,
            grad_for="input_embs_" + self.config.text_b_name,
            grad_target_field_key="grad_class")

    # Attention heads, one field for each layer.
    for i in range(self.model.config.num_hidden_layers):
      ret[f"layer_{i+1}/attention"] = lit_types.AttentionHeads(
          align_in="tokens", align_out="tokens")
    return ret
Example #2
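An input_spec that accepts optional pre-computed input embeddings (required=False), so interpreters such as Integrated Gradients can feed perturbed embeddings back through the model, plus a grad_class field naming the class to take gradients against.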
 def input_spec(self):
     return {
         'input_embs': lit_types.TokenEmbeddings(align='tokens',
                                                 required=False),
         'segment': lit_types.TextSegment(),  # Spec values must be instances, not classes.
         'grad_class': lit_types.CategoryLabel(vocab=['0', '1'])
     }
Example #3
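An output_spec that exposes subword alignment: SubwordOffsets maps the model's wordpiece tokens (wpm_tokens) back to the input tokens, so per-wordpiece embeddings can be re-aligned to words.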
 def output_spec(self):
     return {
         'top_layer_embs': lit_types.TokenEmbeddings(),
         'wpm_tokens': lit_types.Tokens(),
         'offsets': lit_types.SubwordOffsets(
             align_in='tokens', align_out='wpm_tokens')
     }
Example #4
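A binary classifier's output_spec that also returns token-aligned input embeddings for gradient-based attribution.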
 def output_spec(self):
     return {
         'probas': lit_types.MulticlassPreds(
             parent='label', vocab=['0', '1'], null_idx=0),
         'input_embs': lit_types.TokenEmbeddings(align='tokens'),
     }
Example #5
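The same binary classifier extended with the full set of output fields Integrated Gradients reads: tokens, token-aligned input embeddings, their gradients, and the class the gradients were computed against.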
 def output_spec(self):
   return {
       'probas': lit_types.MulticlassPreds(
           parent='label', vocab=['0', '1'], null_idx=0),
       'input_embs': lit_types.TokenEmbeddings(align='tokens'),
       'input_embs_grad': lit_types.TokenGradients(
           align='tokens',
           grad_for='input_embs',
           grad_target='grad_class'),
       'tokens': lit_types.Tokens(),
       'grad_class': lit_types.CategoryLabel(vocab=['0', '1']),
   }
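Note the grad_target keyword here, where Example #1 uses grad_target_field_key: these appear to be the older and newer lit_nlp names for the same TokenGradients field, so match the keyword to the version you have installed.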
Example #6
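The input_spec counterpart to Example #1: one or two text segments, an optional label (regression score or category), optional pre-computed input embeddings for Integrated Gradients, and the gradient target class.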
 def input_spec(self) -> Spec:
     ret = {}
     ret[self.config.text_a_name] = lit_types.TextSegment()
     if self.config.text_b_name:
         ret[self.config.text_b_name] = lit_types.TextSegment()
     if self.is_regression:
         ret[self.config.label_name] = lit_types.RegressionScore(
             required=False)
     else:
         ret[self.config.label_name] = lit_types.CategoryLabel(
             required=False, vocab=self.config.labels)
     # The input_embs_ and grad_class fields are used for Integrated Gradients.
     ret["input_embs_" +
         self.config.text_a_name] = lit_types.TokenEmbeddings(
             align="tokens", required=False)
     if self.config.text_b_name:
         ret["input_embs_" +
             self.config.text_b_name] = lit_types.TokenEmbeddings(
                 align="tokens", required=False)
     ret["grad_class"] = lit_types.CategoryLabel(required=False,
                                                 vocab=self.config.labels)
     return ret
Example #7
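A minimal input_spec: a required text segment plus optional pre-computed input embeddings.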
 def input_spec(self):
   return {'input_embs': lit_types.TokenEmbeddings(align='tokens',
                                                   required=False),
           'segment': lit_types.TextSegment()}
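To show where these specs plug in, here is a minimal, self-contained sketch of a LIT model. The class name and the constant predictions are illustrative only, and it assumes a lit_nlp release like the one these examples appear to target, in which lit_model.Model's abstract methods are input_spec, output_spec, and predict_minibatch (newer releases rename the prediction entry point).

 from lit_nlp.api import model as lit_model
 from lit_nlp.api import types as lit_types


 class ToyBinaryClassifier(lit_model.Model):
   """Hypothetical model: real spec plumbing, hard-coded predictions."""

   def input_spec(self):
     return {'segment': lit_types.TextSegment()}

   def output_spec(self):
     return {
         'probas': lit_types.MulticlassPreds(
             parent='label', vocab=['0', '1'], null_idx=0)
     }

   def predict_minibatch(self, inputs, **kw):
     # A real model would tokenize each ex['segment'] and run inference;
     # here every example gets a uniform distribution over the two classes.
     return [{'probas': [0.5, 0.5]} for _ in inputs]

Each returned dict should supply the keys declared in output_spec, with values matching the declared types; LIT uses the specs to validate the model against the dataset and to decide which interpreters and visualizations apply.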