Example #1
def get_question_tensors_for_clause_tensors_batched(
        batch_size: int,
        vocab: Vocabulary,
        all_slots: Dict[str, torch.LongTensor],
        all_probs: torch.Tensor):
    clause_slots = { k[len("clause-"):] : v for k, v in all_slots.items() if k.startswith("clause-")}
    question_slot_names = ["wh", "aux", "subj", "verb", "obj", "prep", "obj2"]
    clause_slot_names = ["subj", "aux", "verb", "obj", "prep1", "prep1-obj", "prep2", "prep2-obj", "misc", "qarg"]
    stringy_clause_slots = [
        {k : vocab.get_token_from_index(
                v[i].item(),
                namespace = get_slot_label_namespace("clause-%s" % k))
            for k, v in clause_slots.items()}
        for i in range(batch_size)
    ]
    filtered_stringy_clause_slots = []
    stringy_question_slots = []
    question_probs = []
    for clause_slots_i, prob in zip(stringy_clause_slots, all_probs):
        try:
            stringy_question_slots.append(get_question_for_clause(clause_slots_i, vocab))
            filtered_stringy_clause_slots.append(clause_slots_i)
            question_probs.append(prob.item())
        except ValueError as e:
            print(str(e))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda:%s" % torch.cuda.current_device())
    filtered_clause_slots = {
        ("clause-%s" % slot_name) : torch.tensor(
            [vocab.get_token_index(slots[slot_name], namespace = get_slot_label_namespace("clause-%s" % slot_name))
             for slots in filtered_stringy_clause_slots],
            device = device
        ).long()
        for slot_name in clause_slot_names
    }
    question_slots = {
        slot_name : torch.tensor(
            [vocab.get_token_index(slots[slot_name], namespace = get_slot_label_namespace(slot_name))
             for slots in stringy_question_slots],
            device = device
        ).long()
        for slot_name in question_slot_names
    }
    question_probs_tensor = torch.tensor(question_probs, device = device)
    return filtered_clause_slots, question_slots, question_probs_tensor
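# Reading the return values: clauses whose slots cannot be converted to a question are
# dropped, so the three returned structures stay aligned along a filtered batch dimension
# of size num_kept <= batch_size:
#   filtered_clause_slots  -- {"clause-<slot>": LongTensor of shape (num_kept,)}
#   question_slots         -- {"<slot>": LongTensor of shape (num_kept,)}
#   question_probs_tensor  -- float tensor of shape (num_kept,)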
Example #2
def get_question_for_clause(clause_slots, vocab: Vocabulary):
    answer_slot = clause_slots["qarg"]
    wh = answer_slot if answer_slot not in clause_slots else get_wh_for_slot_value(clause_slots[answer_slot])
    subj = clause_slots["subj"] if answer_slot != "subj" else get_gap_for_slot_value(clause_slots["subj"])
    if clause_slots["aux"] == "_" and subj != "_":
        verb = "stem"
        if clause_slots["verb"] == "past":
            aux = "did"
        elif clause_slots["verb"] == "present":
            aux = "does"
        else:
            raise ValueError("Verb slot %s cannot be split" % clause_slots["verb"])
    else:
        aux = clause_slots["aux"]
        verb = clause_slots["verb"] if clause_slots["verb"] != "present" else "presentSingular3rd"

    obj = clause_slots["obj"] if answer_slot != "obj" else get_gap_for_slot_value(clause_slots["obj"])

    if clause_slots["prep1"] != "_" and clause_slots["prep2"] != "_":
        prep = "%s %s" % (clause_slots["prep1"], clause_slots["prep2"])
        try:
            vocab.get_token_index(prep, namespace = get_slot_label_namespace("prep"))
        except KeyError:
            raise ValueError("Preposition bigram is not in vocabulary: %s" % prep)
        if clause_slots["prep1-obj"] != "_":
            # in something/someone for ...
            if answer_slot != "prep1-obj":
                raise ValueError("First preposition cannot have a placeholder object in the presence of a second preposition")
            prep1_gap = get_gap_for_slot_value(clause_slots["prep1-obj"])
            if prep1_gap != "_":
                raise ValueError("Gapped argument of first preposition in a pair must be empty; was: %s" % prep1_gap)
            # in <gap> for ...
            if clause_slots["prep2-obj"] != "_":
                # in <gap> for someone / (doing) something ...
                if clause_slots["misc"] != "_":
                    raise ValueError("When prep2 object fills last slot, misc must be empty; was: %s" % clause_slots["misc"])
                else:
                    # in <gap> for someone / (doing) something?
                    obj2 = clause_slots["prep2-obj"]
            else:
                # in <gap> for <misc>
                obj2 = clause_slots["misc"]
        else:
            # in for ...
            if clause_slots["prep2-obj"] == "_":
                # in for <misc>
                obj2 = clause_slots["misc"] if answer_slot != "misc" else get_gap_for_slot_value(clause_slots["misc"])
            else:
                # in for ?(someone / (doing) something) ...
                if answer_slot == "prep2-obj":
                    prep2_gap = get_gap_for_slot_value(clause_slots["prep2-obj"])
                    if prep2_gap != "_":
                        if clause_slots["misc"] != "_":
                            raise ValueError("When prep2 gap fills last slot, misc must be empty; was: %s" % clause_slots["misc"])
                        obj2 = prep2_gap
                    else:
                        # in for <gap> ...
                        obj2 = "_" if clause_slots["misc"] == "_" else get_gap_for_slot_value(clause_slots["misc"])
                else:
                    # in for someone / (doing) something ...
                    obj2 = get_gap_for_slot_value(clause_slots["prep2-obj"])
                    if clause_slots["misc"] != "_":
                        if answer_slot != "misc" or get_gap_for_slot_value(clause_slots["misc"]) != "_":
                            raise ValueError("When prep2 object fills last slot, misc must be empty (possibly via a gap); was: %s" % clause_slots["misc"])
    else:
        if clause_slots["prep2"] != "_":
            raise ValueError("Prep2 must only be present when prep1 is; had: %s, %s" % (clause_slots["prep1"], clause_slots["prep2"]))
        # prep1 only: in ...
        prep = clause_slots["prep1"]
        if clause_slots["prep1-obj"] == "_":
            obj2 = "_" if clause_slots["misc"] == "_" else get_gap_for_slot_value(clause_slots["misc"])
        else:
            # in ?(someone / (doing) something)
            if answer_slot == "prep1-obj":
                prep_gap = get_gap_for_slot_value(clause_slots["prep1-obj"])
                if prep_gap != "_":
                    obj2 = prep_gap
                    if clause_slots["misc"] != "_":
                        raise ValueError("When prep2 gap fills last slot, misc must be empty; was: %s" % clause_slots["misc"])
                else:
                    obj2 = "_" if clause_slots["misc"] == "_" else get_gap_for_slot_value(clause_slots["misc"])
            else:
                # in someone / (doing) something
                obj2 = clause_slots["prep1-obj"]
                if clause_slots["misc"] != "_":
                    if answer_slot != "misc" or get_gap_for_slot_value(clause_slots["misc"]) != "_":
                        raise ValueError("When prep2 object fills last slot, misc must be empty (possibly via a gap); was: %s" % clause_slots["misc"])

    res = {
        "wh": wh,
        "aux": aux,
        "subj": subj,
        "verb": verb,
        "obj": obj,
        "prep": prep,
        "obj2": obj2
    }
    return res
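# Illustrative mapping (hypothetical values): the wh and gap values below depend on the
# get_wh_for_slot_value / get_gap_for_slot_value helpers, which are not shown here, so
# they are assumptions rather than verified outputs.
#   clause slots:   subj=someone aux=_ verb=past obj=something prep1=_ prep1-obj=_
#                   prep2=_ prep2-obj=_ misc=_ qarg=obj
#   question slots: wh=<wh for "something", e.g. "what"> aux=did subj=someone verb=stem
#                   obj=<gap for "something"> prep=_ obj2=_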
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 slot_names: List[str],
                 input_dim: int,
                 slot_embedding_dim: int = 100,
                 output_dim: int = 200,
                 num_layers: int = 1,
                 recurrent_dropout: float = 0.1,
                 highway: bool = True,
                 share_rnn_cell: bool = False):
        super(SlotSequenceEncoder, self).__init__()
        self._vocab = vocab
        self._slot_names = slot_names
        self._input_dim = input_dim
        self._slot_embedding_dim = slot_embedding_dim
        self._output_dim = output_dim
        self._num_layers = num_layers
        self._recurrent_dropout = recurrent_dropout
        self._highway = highway
        self._share_rnn_cell = share_rnn_cell

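        # One embedding table per slot, sized to that slot's label vocabulary; each is
        # registered via add_module so its parameters are tracked.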
        slot_embedders = []
        for i, n in enumerate(self.get_slot_names()):
            num_labels = self._vocab.get_vocab_size(
                get_slot_label_namespace(n))
            assert num_labels > 0, "Slot named %s has 0 vocab size" % (n)
            embedder = Embedding(num_labels, self._slot_embedding_dim)
            self.add_module('embedder_%s' % n, embedder)
            slot_embedders.append(embedder)
        self._slot_embedders = slot_embedders

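        # Build the stacked slot-sequence RNN: for each layer, one LSTM cell per slot
        # (or a single shared cell when share_rnn_cell is set), plus the linear layers
        # for optional highway connections.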
        rnn_cells = []
        highway_nonlin = []
        highway_lin = []
        for l in range(self._num_layers):
            layer_cells = []
            layer_highway_nonlin = []
            layer_highway_lin = []
            shared_cell = None
            layer_input_size = (self.get_input_dim() + self._slot_embedding_dim) if l == 0 else self._output_dim
            for i, n in enumerate(self._slot_names):
                if share_rnn_cell:
                    if shared_cell is None:
                        shared_cell = LSTMCell(layer_input_size,
                                               self._output_dim)
                        self.add_module('layer_%d_cell' % l, shared_cell)
                        if highway:
                            shared_highway_nonlin = Linear(
                                layer_input_size + self._output_dim,
                                self._output_dim)
                            shared_highway_lin = Linear(layer_input_size,
                                                        self._output_dim,
                                                        bias=False)
                            self.add_module('layer_%d_highway_nonlin' % l,
                                            shared_highway_nonlin)
                            self.add_module('layer_%d_highway_lin' % l,
                                            shared_highway_lin)
                    layer_cells.append(shared_cell)
                    if highway:
                        layer_highway_nonlin.append(shared_highway_nonlin)
                        layer_highway_lin.append(shared_highway_lin)
                else:
                    cell = LSTMCell(layer_input_size, self._output_dim)
                    cell.weight_ih.data.copy_(
                        block_orthonormal_initialization(
                            layer_input_size, self._output_dim, 4).t())
                    cell.weight_hh.data.copy_(
                        block_orthonormal_initialization(
                            self._output_dim, self._output_dim, 4).t())
                    self.add_module('layer_%d_cell_%s' % (l, n), cell)
                    layer_cells.append(cell)
                    if highway:
                        nonlin = Linear(layer_input_size + self._output_dim,
                                        self._output_dim)
                        lin = Linear(layer_input_size,
                                     self._output_dim,
                                     bias=False)
                        nonlin.weight.data.copy_(
                            block_orthonormal_initialization(
                                layer_input_size + self._output_dim,
                                self._output_dim, 1).t())
                        lin.weight.data.copy_(
                            block_orthonormal_initialization(
                                layer_input_size, self._output_dim, 1).t())
                        self.add_module('layer_%d_highway_nonlin_%s' % (l, n),
                                        nonlin)
                        self.add_module('layer_%d_highway_lin_%s' % (l, n),
                                        lin)
                        layer_highway_nonlin.append(nonlin)
                        layer_highway_lin.append(lin)

            rnn_cells.append(layer_cells)
            highway_nonlin.append(layer_highway_nonlin)
            highway_lin.append(layer_highway_lin)

        self._rnn_cells = rnn_cells
        if self._highway:
            self._highway_nonlin = highway_nonlin
            self._highway_lin = highway_lin
Example #4
    def __init__(self,
                 vocab: Vocabulary,
                 slot_names: List[str],
                 input_dim: int,
                 slot_hidden_dim: int = 100,
                 rnn_hidden_dim: int = 200,
                 slot_embedding_dim: int = 100,
                 num_layers: int = 1,
                 recurrent_dropout: float = 0.1,
                 highway: bool = True,
                 share_rnn_cell: bool = False,
                 share_slot_hidden: bool = False,
                 clause_mode: bool = False): # clause_mode flag no longer used
        super(SlotSequenceGenerator, self).__init__()
        self.vocab = vocab
        self._slot_names = slot_names
        self._input_dim = input_dim
        self._slot_embedding_dim = slot_embedding_dim
        self._slot_hidden_dim = slot_hidden_dim
        self._rnn_hidden_dim = rnn_hidden_dim
        self._num_layers = num_layers
        self._recurrent_dropout = recurrent_dropout

        question_space_size = 1
        for slot_name in self._slot_names:
            num_values_for_slot = len(self.vocab.get_index_to_token_vocabulary(get_slot_label_namespace(slot_name)))
            question_space_size *= num_values_for_slot
            logger.info("%s values for slot %s" % (num_values_for_slot, slot_name))
        logger.info("Slot sequence generation space: %s possible sequences" % question_space_size)

        slot_embedders = []
        for i, n in enumerate(self.get_slot_names()[:-1]):
            num_labels = self.vocab.get_vocab_size(get_slot_label_namespace(n))
            assert num_labels > 0, "Slot named %s has 0 vocab size"%(n)
            embedder = Embedding(num_labels, self._slot_embedding_dim)
            self.add_module('embedder_%s'%n, embedder)
            slot_embedders.append(embedder)

        self._slot_embedders = slot_embedders

        self._highway = highway

        rnn_cells = []
        highway_nonlin = []
        highway_lin = []
        for l in range(self._num_layers):
            layer_cells = []
            layer_highway_nonlin = []
            layer_highway_lin = []
            shared_cell = None
            layer_input_size = self._input_dim + self._slot_embedding_dim if l == 0 else self._rnn_hidden_dim
            for i, n in enumerate(self._slot_names):
                if share_rnn_cell:
                    if shared_cell is None:
                        shared_cell = LSTMCell(layer_input_size, self._rnn_hidden_dim)
                        self.add_module('layer_%d_cell'%l, shared_cell)
                        if highway:
                            shared_highway_nonlin = Linear(layer_input_size + self._rnn_hidden_dim, self._rnn_hidden_dim)
                            shared_highway_lin = Linear(layer_input_size, self._rnn_hidden_dim, bias = False)
                            self.add_module('layer_%d_highway_nonlin'%l, shared_highway_nonlin)
                            self.add_module('layer_%d_highway_lin'%l, shared_highway_lin)
                    layer_cells.append(shared_cell)
                    if highway:
                        layer_highway_nonlin.append(shared_highway_nonlin)
                        layer_highway_lin.append(shared_highway_lin)
                else:
                    cell = LSTMCell(layer_input_size, self._rnn_hidden_dim)
                    cell.weight_ih.data.copy_(block_orthonormal_initialization(layer_input_size, self._rnn_hidden_dim, 4).t())
                    cell.weight_hh.data.copy_(block_orthonormal_initialization(self._rnn_hidden_dim, self._rnn_hidden_dim, 4).t())
                    self.add_module('layer_%d_cell_%s'%(l, n), cell)
                    layer_cells.append(cell)
                    if highway:
                        nonlin = Linear(layer_input_size + self._rnn_hidden_dim, self._rnn_hidden_dim)
                        lin = Linear(layer_input_size, self._rnn_hidden_dim, bias = False)
                        nonlin.weight.data.copy_(block_orthonormal_initialization(layer_input_size + self._rnn_hidden_dim, self._rnn_hidden_dim, 1).t())
                        lin.weight.data.copy_(block_orthonormal_initialization(layer_input_size, self._rnn_hidden_dim, 1).t())
                        self.add_module('layer_%d_highway_nonlin_%s'%(l, n), nonlin)
                        self.add_module('layer_%d_highway_lin_%s'%(l, n), lin)
                        layer_highway_nonlin.append(nonlin)
                        layer_highway_lin.append(lin)

            rnn_cells.append(layer_cells)
            highway_nonlin.append(layer_highway_nonlin)
            highway_lin.append(layer_highway_lin)

        self._rnn_cells = rnn_cells
        if highway:
            self._highway_nonlin = highway_nonlin
            self._highway_lin = highway_lin

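        # Per-slot output heads: a hidden projection (shared across slots when
        # share_slot_hidden is set) followed by a per-slot label classifier.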
        shared_slot_hidden = None
        slot_hiddens = []
        slot_preds = []
        slot_num_labels = []
        for i, n in enumerate(self._slot_names):
            num_labels = self.vocab.get_vocab_size(get_slot_label_namespace(n))
            slot_num_labels.append(num_labels)

            if share_slot_hidden:
                if shared_slot_hidden is None:
                    shared_slot_hidden = Linear(self._rnn_hidden_dim, self._slot_hidden_dim)
                    self.add_module('slot_hidden', shared_slot_hidden)
                slot_hiddens.append(shared_slot_hidden)
            else:
                slot_hidden = Linear(self._rnn_hidden_dim, self._slot_hidden_dim)
                slot_hiddens.append(slot_hidden)
                self.add_module('slot_hidden_%s'%n, slot_hidden)

            slot_pred = Linear(self._slot_hidden_dim, num_labels)
            slot_preds.append(slot_pred)
            self.add_module('slot_pred_%s'%n, slot_pred)

        self._slot_hiddens = slot_hiddens
        self._slot_preds = slot_preds
        self._slot_num_labels = slot_num_labels

        self._start_symbol = Parameter(torch.Tensor(self._slot_embedding_dim).normal_(0, 1))
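# A minimal, self-contained sketch of the search-space count logged in the constructor
# above: the number of possible slot sequences is the product of each slot's label
# vocabulary size. The slot sizes below are made-up placeholders, not real vocabulary counts.
def count_slot_sequences(slot_vocab_sizes):
    total = 1
    for size in slot_vocab_sizes.values():
        total *= size
    return total

hypothetical_sizes = {"wh": 6, "aux": 12, "subj": 4, "verb": 40, "obj": 4, "prep": 60, "obj2": 6}
print(count_slot_sequences(hypothetical_sizes))  # 6 * 12 * 4 * 40 * 4 * 60 * 6 = 16588800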
Example #5
    def beam_decode(self,
                    inputs, # shape: 1, input_dim
                    max_beam_size,
                    min_beam_probability,
                    clause_mode: bool = False):
        min_beam_log_probability = math.log(min_beam_probability)
        batch_size, input_dim = inputs.size()
        if batch_size != 1:
            raise ConfigurationError("beam_decode_single must be run with a batch size of 1.")
        if input_dim != self.get_input_dim():
            raise ConfigurationError("input dimension must match dimensionality of slot sequence model input.")

        ## metadata to recover sequences
        # slot_name -> List/Tensor of shape (beam_size) where value is index into slot's beam
        backpointers = {}
        # slot_name -> list (length <= beam_size) of indices indicating slot values
        slot_beam_labels = {}

        ## initialization for beam search loop
        init_embedding, init_mem = self._init_recurrence(inputs)
        # current state of the beam search: list of (input embedding, memory cells, log_prob), ordered by probability
        current_beam_states = [(init_embedding, init_mem, 0.)]

        for slot_index, slot_name in enumerate(self._slot_names):
            ending_clause_with_qarg = clause_mode and slot_index == (len(self._slot_names) - 1) and slot_name == "clause-qarg"
            # list of tuples (backpointer, slot_value_index, new_embedding, new_mem, log_prob)
            candidate_new_beam_states = []
            for i, (emb, mem, prev_log_prob) in enumerate(current_beam_states):
                recurrence_dict = self._slot_quasi_recurrence(slot_index, slot_name, inputs, emb, mem)
                next_mem = recurrence_dict["next_mem"]
                logits = recurrence_dict["logits"].squeeze()
                log_probabilities = F.log_softmax(logits, -1)
                num_slot_values = self.vocab.get_vocab_size(get_slot_label_namespace(slot_name))
                for pred_slot_index in range(0, num_slot_values):
                    if len(log_probabilities.size()) == 0: # this only happens with slot vocab size of 1
                        log_prob = log_probabilities.item() + prev_log_prob
                    else:
                        log_prob = log_probabilities[pred_slot_index].item() + prev_log_prob
                    if slot_index < len(self._slot_names) - 1:
                        new_input_embedding = self._slot_embedders[slot_index](inputs.new([pred_slot_index]).long())
                    else:
                        new_input_embedding = None
                    # keep all expansions of the last step --- for now --- if we're on the qarg slot of a clause
                    if ending_clause_with_qarg or log_prob >= min_beam_log_probability:
                        candidate_new_beam_states.append((i, pred_slot_index, new_input_embedding, next_mem, log_prob))
            candidate_new_beam_states.sort(key = lambda t: t[4], reverse = True)
            # ditto the comment above; keeping all expansions of last step for clauses; we'll filter them later
            new_beam_states = candidate_new_beam_states[:max_beam_size] if not ending_clause_with_qarg else candidate_new_beam_states
            backpointers[slot_name] = [t[0] for t in new_beam_states]
            slot_beam_labels[slot_name] = [t[1] for t in new_beam_states]
            current_beam_states = [(t[2], t[3], t[4]) for t in new_beam_states]

        final_beam_size = len(current_beam_states)
        final_slots = {}
        for slot_name in reversed(self._slot_names):
            final_slots[slot_name] = inputs.new_zeros([final_beam_size], dtype = torch.int32)
        final_log_probs = inputs.new_zeros([final_beam_size], dtype = torch.float64)
        for beam_index in range(final_beam_size):
            final_log_probs[beam_index] = current_beam_states[beam_index][2]
            current_backpointer = beam_index
            for slot_name in reversed(self._slot_names):
                final_slots[slot_name][beam_index] = slot_beam_labels[slot_name][current_backpointer]
                current_backpointer = backpointers[slot_name][current_backpointer]

        # now if we're in clause mode, we need to filter the expanded beam
        if clause_mode:
            chosen_beam_indices = []
            for beam_index in range(final_beam_size):
                # TODO fix for abstracted slots, which have a different name. ...later. requires a nontrivial refactor
                qarg_name = self.vocab.get_token_from_index(final_slots["clause-qarg"][beam_index].item(), get_slot_label_namespace("clause-qarg"))
                qarg = "clause-%s" % qarg_name
                if qarg in self.get_slot_names():
                    # remove core arguments which are invalid
                    arg_value = self.vocab.get_token_from_index(final_slots[qarg][beam_index].item(), get_slot_label_namespace(qarg))
                    should_keep = arg_value != "_"
                else:
                    should_keep = True
                if should_keep:
                    chosen_beam_indices.append(beam_index)

            device = torch.device("cpu")
            if torch.cuda.is_available():
                device = torch.device("cuda:%s" % torch.cuda.current_device())
            chosen_beam_vector = torch.tensor(chosen_beam_indices, device = device).long()
            for slot_name in self._slot_names:
                final_slots[slot_name] = final_slots[slot_name].gather(0, chosen_beam_vector)
            final_log_probs = final_log_probs.gather(0, chosen_beam_vector)

        final_slot_indices = {
            slot_name: slot_indices.long().tolist()
            for slot_name, slot_indices in final_slots.items() }
        final_slot_labels = {
            slot_name: [self.vocab.get_token_from_index(index, get_slot_label_namespace(slot_name))
                        for index in slot_indices]
            for slot_name, slot_indices in final_slot_indices.items()
        }
        final_probs = final_log_probs.exp()
        return final_slot_indices, final_slot_labels, final_probs.tolist()
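# A minimal, self-contained sketch of the backpointer bookkeeping used in beam_decode:
# each slot records, for every surviving beam entry, the index of its parent entry in the
# previous slot's beam, so complete sequences are recovered by walking the pointers in
# reverse slot order. The slot names and label indices here are toy placeholders.
toy_slot_names = ["wh", "aux", "subj"]
toy_backpointers = {"wh": [0, 0], "aux": [0, 1], "subj": [1, 0]}      # slot -> parent beam index
toy_slot_beam_labels = {"wh": [3, 5], "aux": [2, 7], "subj": [4, 1]}  # slot -> chosen label index
for beam_index in range(2):
    labels = {}
    pointer = beam_index
    for slot_name in reversed(toy_slot_names):
        labels[slot_name] = toy_slot_beam_labels[slot_name][pointer]
        pointer = toy_backpointers[slot_name][pointer]
    print(beam_index, labels)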
Example #6
    def predict(self, inputs: JsonDict) -> JsonDict:
        qg_instances = list(self._question_model_dataset_reader.sentence_json_to_instances(inputs, verbs_only = True))
        qa_instances = list(self._question_to_span_model_dataset_reader.sentence_json_to_instances(inputs, verbs_only = True))
        if self._tan_model is not None:
            tan_instances = list(self._tan_model_dataset_reader.sentence_json_to_instances(inputs, verbs_only = True))
            tan_outputs = self._tan_model.forward_on_instances(tan_instances)
        else:
            tan_outputs = [None for _ in qg_instances]
        if self._span_to_tan_model is not None:
            span_to_tan_instances = list(self._span_to_tan_model_dataset_reader.sentence_json_to_instances(inputs, verbs_only = True))
        else:
            span_to_tan_instances = [None for _ in qg_instances]
        if self._animacy_model is not None:
            animacy_instances = list(self._animacy_model_dataset_reader.sentence_json_to_instances(inputs, verbs_only = True))
        else:
            animacy_instances = [None for _ in qg_instances]

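        # For each verb: beam-decode candidate questions, score candidate answer spans for
        # each question with the question-to-span model, and optionally run the animacy and
        # span-to-TAN models over the union of predicted spans.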
        verb_dicts = []
        for (qg_instance, qa_instance_template, tan_output, span_to_tan_instance, animacy_instance) in zip(qg_instances, qa_instances, tan_outputs, span_to_tan_instances, animacy_instances):
            qg_instance.index_fields(self._question_model.vocab)
            qgen_input_tensors = move_to_device(
                Batch([qg_instance]).as_tensor_dict(),
                self._question_model._get_prediction_device())
            _, all_question_slots, question_probs = self._question_model.beam_decode(
                text = qgen_input_tensors["text"],
                predicate_indicator = qgen_input_tensors["predicate_indicator"],
                predicate_index = qgen_input_tensors["predicate_index"],
                max_beam_size = self._question_beam_size,
                min_beam_probability = self._question_minimum_threshold,
                clause_mode = self._clause_mode)

            verb_qa_instances = []
            question_slots_list = []
            for i in range(len(question_probs)):
                qa_instance = Instance({k: v for k, v in qa_instance_template.fields.items()})
                question_slots = {}
                for slot_name in self._question_to_span_model.get_slot_names():
                    slot_label = all_question_slots[slot_name][i]
                    question_slots[slot_name] = slot_label
                    slot_label_field = LabelField(slot_label, get_slot_label_namespace(slot_name))
                    qa_instance.add_field(slot_name, slot_label_field, self._question_to_span_model.vocab)
                question_slots_list.append(question_slots)
                verb_qa_instances.append(qa_instance)
            if len(verb_qa_instances) > 0:
                qa_outputs = self._question_to_span_model.forward_on_instances(verb_qa_instances)
                if self._animacy_model is not None or self._span_to_tan_model is not None:
                    all_spans = list(set([s for qa_output in qa_outputs for s, p in qa_output["spans"] if p >= self._span_minimum_threshold]))
                if self._animacy_model is not None:
                    animacy_instance.add_field("animacy_spans", ListField([SpanField(s.start(), s.end(), animacy_instance["text"]) for s in all_spans]), self._animacy_model.vocab)
                    animacy_output = self._animacy_model.forward_on_instance(animacy_instance)
                else:
                    animacy_output = None
                if self._span_to_tan_model is not None:
                    span_to_tan_instance.add_field("tan_spans", ListField([SpanField(s.start(), s.end(), span_to_tan_instance["text"]) for s in all_spans]))
                    span_to_tan_output = self._span_to_tan_model.forward_on_instance(span_to_tan_instance)
                else:
                    span_to_tan_output = None
            else:
                qa_outputs = []
                animacy_output = None
                span_to_tan_output = None

            qa_beam = []
            for question_slots, question_prob, qa_output in zip(question_slots_list, question_probs, qa_outputs):
                scored_spans = [(s, p) for s, p in qa_output["spans"] if p >= self._span_minimum_threshold]
                invalid_dict = {}
                if self._question_to_span_model.classifies_invalids():
                    invalid_dict["invalidProb"] = qa_output["invalid_prob"].item()
                for span, span_prob in scored_spans:
                    qa_beam.append({
                        "questionSlots": question_slots,
                        "questionProb": question_prob,
                        **invalid_dict,
                        "span": [span.start(), span.end() + 1],
                        "spanProb": span_prob
                    })
            beam = { "qa_beam": qa_beam }
            if tan_output is not None:
                beam["tans"] = [
                    (self._tan_model.vocab.get_token_from_index(i, namespace = "tan-string-labels"), p)
                    for i, p in enumerate(tan_output["probs"].tolist())
                    if p >= self._tan_minimum_threshold
                ]
            if animacy_output is not None:
                beam["animacy"] = [
                    ([s.start(), s.end() + 1], p)
                    for s, p in zip(all_spans, animacy_output["probs"].tolist())
                ]
            if span_to_tan_output is not None:
                beam["span_tans"] = [
                    ([s.start(), s.end() + 1], [
                        (self._span_to_tan_model.vocab.get_token_from_index(i, namespace = "tan-string-labels"), p)
                        for i, p in enumerate(probs)
                        if p >= self._tan_minimum_threshold])
                    for s, probs in zip(all_spans, span_to_tan_output["probs"].tolist())
                ]
            verb_dicts.append({
                "verbIndex": qg_instance["metadata"]["verb_index"],
                "verbInflectedForms": qg_instance["metadata"]["verb_inflected_forms"],
                "beam": beam
            })
        return {
            "sentenceId": inputs["sentenceId"],
            "sentenceTokens": inputs["sentenceTokens"],
            "verbs": verb_dicts
        }
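# Shape of the JSON returned by predict, as assembled above:
# {
#   "sentenceId": ...,
#   "sentenceTokens": [...],
#   "verbs": [
#     {
#       "verbIndex": int,
#       "verbInflectedForms": {...},
#       "beam": {
#         "qa_beam": [ {"questionSlots": {...}, "questionProb": float,
#                       "invalidProb": float (only if the span model classifies invalids),
#                       "span": [start, end_exclusive], "spanProb": float}, ... ],
#         "tans": [...]       (only when a TAN model is configured),
#         "animacy": [...]    (only when an animacy model is configured),
#         "span_tans": [...]  (only when a span-to-TAN model is configured)
#       }
#     }, ...
#   ]
# }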