예제 #1
0
    def inference(self, dev_data):
        """Predict one answer span per example in ``dev_data``.

        Examples are scored in mini-batches of 16. For each example the span
        ``(start, end)`` with ``start <= end`` that maximises
        ``start_logit[start] + end_logit[end]`` is selected.

        :param dev_data: non-empty, sliceable sequence of examples exposing
            ``text_ids`` and ``ptr_input_ids``.
        :return: ``(outputs, output_spans)``. ``outputs`` is the first value
            returned by ``self.forward`` for every mini-batch, concatenated
            along dim 0 (computed without autograd tracking).
            ``output_spans`` is a list of ``(start, end)`` int pairs in input
            order.
        :raises ValueError: if ``dev_data`` is empty.
        """
        if len(dev_data) == 0:
            # Without this guard the loop below becomes range(0, 0, 0), which raises an
            # opaque "arg 3 must not be zero" error (and torch.cat([]) would fail too).
            raise ValueError('inference() needs at least one example')
        self.eval()
        batch_size = 16

        outputs, output_spans = [], []
        # Inference only: do not record an autograd graph for the forward passes.
        with torch.no_grad():
            for batch_start_id in tqdm(range(0, len(dev_data), batch_size)):
                mini_batch = dev_data[batch_start_id: batch_start_id + batch_size]
                _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch], bu.pad_id)
                encoder_input_ids = ops.pad_batch([exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
                # span_extract_output: [batch_size, 2, encoder_seq_len]
                output, span_extract_output = self.forward(encoder_input_ids, text_masks)
                outputs.append(output)
                encoder_seq_len = span_extract_output.size(2)
                # [batch_size, encoder_seq_len]
                start_logit = span_extract_output[:, 0, :]
                end_logit = span_extract_output[:, 1, :]
                # span_logit[b, s, e] = start_logit[b, s] + end_logit[b, e]
                # [batch_size, encoder_seq_len, encoder_seq_len]
                span_logit = start_logit.unsqueeze(2) + end_logit.unsqueeze(1)
                # Upper-triangular mask keeps only spans with start <= end.
                valid_span_pos = ops.ones_var_cuda(
                    [len(span_logit), encoder_seq_len, encoder_seq_len]).triu()
                span_logit = span_logit - (1 - valid_span_pos) * ops.HUGE_INT

                # One argmax over each flattened (start, end) grid. divmod on Python ints
                # recovers (start, end) exactly, independent of how the installed torch
                # version implements `/` on integer tensors.
                flat_span_pos = span_logit.reshape(len(span_logit), -1).argmax(dim=1)
                output_spans.extend(
                    divmod(span_pos, encoder_seq_len) for span_pos in flat_span_pos.tolist())
        return torch.cat(outputs), output_spans
    def inference(self, dev_data):
        """Predict one ``(start, end)`` span per example in ``dev_data``.

        Examples are scored in mini-batches of 32. For each example the span with
        ``start <= end`` that maximises ``start_logit[start] + end_logit[end]`` is
        chosen.

        :param dev_data: sliceable sequence of examples exposing ``text_ids`` and
            ``ptr_input_ids``.
        :return: list of ``(start, end)`` int pairs, one per example, in input
            order (empty when ``dev_data`` is empty).
        """
        self.eval()
        batch_size = 32

        output_spans = []
        # Inference only: do not record an autograd graph for the forward passes.
        with torch.no_grad():
            for batch_start_id in tqdm(range(0, len(dev_data), batch_size)):
                mini_batch = dev_data[batch_start_id:batch_start_id + batch_size]
                _, text_masks = ops.pad_batch([exp.text_ids for exp in mini_batch],
                                              bu.pad_id)
                encoder_input_ids = ops.pad_batch(
                    [exp.ptr_input_ids for exp in mini_batch], bu.pad_id)
                # [batch_size, 2, encoder_seq_len]
                output = self.forward(encoder_input_ids, text_masks)
                encoder_seq_len = output.size(2)
                # [batch_size, encoder_seq_len]
                start_logit = output[:, 0, :]
                end_logit = output[:, 1, :]
                # span_logit[b, s, e] = start_logit[b, s] + end_logit[b, e]
                # [batch_size, encoder_seq_len, encoder_seq_len]
                span_logit = start_logit.unsqueeze(2) + end_logit.unsqueeze(1)
                # Size the mask by the actual mini-batch: the final mini-batch can hold
                # fewer than `batch_size` examples, and a [batch_size, ...] mask would
                # then fail to broadcast against span_logit. The upper triangle keeps
                # only spans with start <= end.
                valid_span_pos = ops.ones_var_cuda(
                    [len(span_logit), encoder_seq_len, encoder_seq_len]).triu()
                span_logit = span_logit - (1 - valid_span_pos) * ops.HUGE_INT

                # One argmax over each flattened (start, end) grid. divmod on Python ints
                # recovers (start, end) exactly, independent of how the installed torch
                # version implements `/` on integer tensors.
                flat_span_pos = span_logit.reshape(len(span_logit), -1).argmax(dim=1)
                output_spans.extend(
                    divmod(span_pos, encoder_seq_len) for span_pos in flat_span_pos.tolist())

        return output_spans
예제 #3
0
def beam_search(alpha, model, decoder, decoder_embeddings, num_steps, beam_size, encoder_final_hidden,
                encoder_hiddens=None, encoder_masks=None, encoder_ptr_value_ids=None, constant_hiddens=None,
                constant_hidden_masks=None, schema_hiddens=None, schema_hidden_masks=None, table_masks=None,
                schema_memory_masks=None, db_scope=None, no_from=False, start_embedded=None):
    """Decode with length-normalised beam search.

    Hypothesis scores are length normalised GNMT-style: the accumulated
    log-score is divided by ``(5 + len) ** alpha / (5 + 1) ** alpha``. Here
    ``len`` stops growing once a hypothesis has produced ``eos``.

    Tensors use ``full_size = batch_size * beam_size`` rows, with the beams of
    one example stored contiguously (see ``beam_offset``). Encoder-side inputs
    are expanded to that layout with ``ops.tile_along_beam``.

    :param alpha: length-normalisation exponent.
    :param model: model id. ``BRIDGE``, ``SEQ2SEQ_PG`` and ``SEQ2SEQ`` are
        handled; anything else raises ``NotImplementedError``.
    :param decoder: decoder module, called once per step. Attributes such as
        ``vocab``, ``vocab_size`` and ``return_hiddens`` are read from it.
    :param decoder_embeddings: callable mapping step input ids to embeddings.
    :param num_steps: number of decoding steps. The loop always runs all of
        them; there is no early stop.
    :param beam_size: number of hypotheses kept per example.
    :param encoder_final_hidden: initial decoder state. It may be a tensor, a
        2-tuple (e.g. LSTM ``(h, c)``) or ``None``; it is tiled along dim 1.
    :param encoder_hiddens: encoder memory passed to the decoder. It is also
        used to infer ``batch_size`` when given.
    :param encoder_masks: padding mask for ``encoder_hiddens``.
    :param encoder_ptr_value_ids: optional ids forwarded to the pointer
        decoders.
    :param constant_hiddens: used only to infer ``batch_size`` when
        ``encoder_hiddens`` is ``None``.
    :param constant_hidden_masks: padding mask over the constant positions.
        The constant length is ``size(1) - mask.sum(1)``.
    :param schema_hiddens: accepted but not read in this function.
    :param schema_hidden_masks: BRIDGE only; padding mask over the schema
        positions, used the same way as ``constant_hidden_masks``.
    :param table_masks: BRIDGE only; nonzero at schema positions that are
        tables.
    :param schema_memory_masks: accepted but not read in this function.
    :param db_scope: BRIDGE only; optional ``(table_pos, table_field_scopes)``.
        When given, the schema-consistency memory masks are enabled.
    :param no_from: initial field memory mask. All ones if true, all zeros
        otherwise. In the zero case, when ``db_scope`` is given, fields are
        enabled as their tables are decoded.
    :param start_embedded: must be ``None``; a custom start embedding is not
        implemented.
    :return: ``(outputs, pred_score, seq_p_pointers, seq_text_ptr_weights,
        seq_len)`` for ``SEQ2SEQ_PG``/``BRIDGE``, or ``(outputs, pred_score,
        seq_text_ptr_weights, seq_len)`` for ``SEQ2SEQ``.
        ``outputs`` holds the ``[full_size, num_steps]`` predicted ids, and
        ``pred_score`` the ``[full_size, 1]`` normalised scores. If
        ``decoder.return_hiddens`` is set, the function instead returns
        ``(output_obj, hiddens)``, with each hidden history viewed as
        ``[-1, batch_size, beam_size, num_steps, hidden_dim]``.
    """

    def compute_memory_inputs(constant_seq_len, schema_seq_len, table_masks):
        """Build the type id of every memory (constant + schema) position.

        In each row, positions ``[0, constant_seq_len)`` are constants and get
        ``decoder.vocab.value_id``. All remaining positions, including any
        padding up to the longest row, get ``decoder.vocab.field_id``. The
        exception is positions flagged by ``table_masks`` (whose column index is
        relative to the start of the schema segment), which get
        ``decoder.vocab.table_id``. ``batch_size`` and ``no_from`` are read from
        the enclosing scope.

        :return memory_inputs: [batch_size, memory_max_size] type ids, e.g.
            memory_inputs[i]: [value_id, value_id, ..., table_id, field_id, ..., table_id, ...]
        :return memory_input_table_masks: 1 where memory_inputs == table_id.
        :return memory_input_field_masks: all ones if ``no_from`` else all zeros.
        :return memory_input_constant_masks: 1 on constant positions.
        """
        memory_size = constant_seq_len + schema_seq_len
        memory_max_size = int(max(memory_size))
        memory_input_constant_masks = (ops.batch_arange_cuda(batch_size, memory_max_size) <
                                       constant_seq_len.unsqueeze(1)).long()
        memory_input_schema_masks = 1 - memory_input_constant_masks
        memory_inputs = memory_input_constant_masks * decoder.vocab.value_id + \
                        memory_input_schema_masks * decoder.vocab.field_id
        memory_inputs = memory_inputs.view(-1)
        # Flat index of each table: row * memory_max_size + (constant_seq_len[row] + schema column).
        table_idx1, table_idx2 = torch.nonzero(table_masks).split(1, dim=1)
        table_pos_ = table_idx1 * memory_max_size + constant_seq_len[table_idx1] + table_idx2
        memory_inputs[table_pos_] = decoder.vocab.table_id
        memory_inputs = memory_inputs.view(batch_size, memory_max_size)

        memory_input_table_masks = (memory_inputs == decoder.vocab.table_id).long()
        memory_input_field_masks = ops.int_ones_var_cuda(memory_input_table_masks.size()) if no_from else \
                                   ops.int_zeros_var_cuda(memory_input_table_masks.size())
        return memory_inputs, memory_input_table_masks, memory_input_field_masks, memory_input_constant_masks

    def get_vocab_cat_masks():
        """Return int vocabulary masks for clause, operator and join tokens,
        plus ``1 - clause - op - join`` for the remaining tokens (this assumes
        the three masks are disjoint).

        Only referenced from commented-out code in this function.
        """
        v_clause_mask = ops.int_var_cuda(decoder.vocab.clause_mask)
        v_op_mask = ops.int_var_cuda(decoder.vocab.op_mask)
        v_join_mask = ops.int_var_cuda(decoder.vocab.join_mask)
        v_others_mask = 1 - v_clause_mask - v_op_mask - v_join_mask
        return v_clause_mask, v_op_mask, v_join_mask, v_others_mask

    def offset_hidden(h, beam_offset):
        """Reorder a decoder state (tensor or 2-tuple) along dim 1 by ``beam_offset``."""
        if isinstance(h, tuple):
            return torch.index_select(h[0], 1, beam_offset), torch.index_select(h[1], 1, beam_offset)
        else:
            return torch.index_select(h, 1, beam_offset)

    def update_beam_search_history(history, state, beam_offset, offset_dim, seq_dim, offset_state=False):
        """Append ``state`` to ``history`` along ``seq_dim``.

        First ``history`` (and ``state`` as well, if ``offset_state``) is
        reordered along ``offset_dim`` by ``beam_offset``, so that each row
        follows its surviving parent hypothesis. On the first step
        (``history is None``), ``state`` is returned unchanged.
        """
        if history is None:
            return state
        else:
            history = torch.index_select(history, offset_dim, beam_offset)
            if offset_state:
                state = torch.index_select(state, offset_dim, beam_offset)
            return torch.cat([history, state], seq_dim)

    if encoder_hiddens is None:
        batch_size = constant_hiddens.size(0)
    else:
        batch_size = encoder_hiddens.size(0)
    full_size = batch_size * beam_size

    start_id = decoder.vocab.start_id
    eos_id = decoder.vocab.eos_id
    # Vocabulary ids of digit tokens. When decoded from the vocabulary, these
    # are re-typed as values in the BRIDGE branch below.
    digit_0_id = decoder.vocab.to_idx('0')
    digit_1_id = decoder.vocab.to_idx('1')
    digit_2_id = decoder.vocab.to_idx('2')
    digit_3_id = decoder.vocab.to_idx('3')
    digit_4_id = decoder.vocab.to_idx('4')
    digit_5_id = decoder.vocab.to_idx('5')
    digit_6_id = decoder.vocab.to_idx('6')
    digit_7_id = decoder.vocab.to_idx('7')
    digit_8_id = decoder.vocab.to_idx('8')
    digit_9_id = decoder.vocab.to_idx('9')
    digit_10_id = decoder.vocab.to_idx('10')
    digit_11_id = decoder.vocab.to_idx('11')
    digit_12_id = decoder.vocab.to_idx('12')
    digit_s0_id = decoder.vocab.to_idx('##0')
    digit_s1_id = decoder.vocab.to_idx('##1')
    digit_s2_id = decoder.vocab.to_idx('##2')
    digit_s3_id = decoder.vocab.to_idx('##3')
    digit_s4_id = decoder.vocab.to_idx('##4')
    digit_s5_id = decoder.vocab.to_idx('##5')
    # [full_size, 1] flag: the hypothesis has already emitted eos.
    seen_eos = ops.byte_zeros_var_cuda([full_size, 1])
    # Becomes a [full_size, 1] float tensor after the first step.
    seq_len = 0

    if type(encoder_final_hidden) is tuple:
        assert(len(encoder_final_hidden) == 2)
        hidden = (ops.tile_along_beam(encoder_final_hidden[0], beam_size, dim=1),
                  ops.tile_along_beam(encoder_final_hidden[1], beam_size, dim=1))
    elif encoder_final_hidden is not None:
        hidden = ops.tile_along_beam(encoder_final_hidden, beam_size, dim=1)
    else:
        hidden = None

    # NOTE(review): constant_hidden_masks defaults to None, but it is used
    # unconditionally here; callers must always pass it.
    constant_seq_len = constant_hidden_masks.size(1) - constant_hidden_masks.sum(dim=1)
    vocab_masks, memory_masks = None, None
    if model in [BRIDGE]:
        # BRIDGE points into a memory of constants followed by schema items.
        schema_seq_len = schema_hidden_masks.size(1) - schema_hidden_masks.sum(dim=1)
        memory_inputs, m_table_masks, m_field_masks, m_value_masks = \
            compute_memory_inputs(constant_seq_len, schema_seq_len, table_masks)
        if db_scope is not None:
            # vocab_mask = ops.int_ones_var_cuda(decoder.vocab.size)
            # vocab_mask[decoder.vocab.to_idx('from')] = 1
            # vocab_mask[decoder.vocab.to_idx('(')] = 1
            # v_clause_mask, v_op_mask, v_join_mask, v_others_mask = get_vocab_cat_masks()
            memory_masks = ops.int_zeros_var_cuda([batch_size, memory_inputs.size(1)])

            table_pos, table_field_scopes = db_scope
            # Shift table positions into memory coordinates (past the constants).
            # Entries equal to 0 stay 0 (presumably padding; verify against caller).
            table_memory_pos = constant_seq_len.unsqueeze(1) * (table_pos > 0).long() + table_pos
            table_memory_pos = ops.tile_along_beam(table_memory_pos, beam_size)
            table_field_scopes = ops.tile_along_beam(table_field_scopes, beam_size)
            db_scope = (table_memory_pos, table_field_scopes)
    else:
        memory_inputs = None

    if model in [SEQ2SEQ_PG, BRIDGE]:
        # Expand per-example tensors to the [full_size, ...] beam layout.
        encoder_hiddens = ops.tile_along_beam(encoder_hiddens, beam_size)
        encoder_masks = ops.tile_along_beam(encoder_masks, beam_size)
        if memory_masks is not None:
            # assert(vocab_mask is not None)
            # vocab_masks = ops.tile_along_beam(vocab_mask.unsqueeze(0), batch_size * beam_size)
            # v_clause_masks = ops.tile_along_beam(v_clause_mask.unsqueeze(0), batch_size * beam_size)
            # v_op_masks = ops.tile_along_beam(v_op_mask.unsqueeze(0), batch_size * beam_size)
            # v_join_masks = ops.tile_along_beam(v_join_mask.unsqueeze(0), batch_size * beam_size)
            # v_others_masks = ops.tile_along_beam(v_others_mask.unsqueeze(0), batch_size * beam_size)
            memory_masks = ops.tile_along_beam(memory_masks, beam_size)
            m_table_masks = ops.tile_along_beam(m_table_masks, beam_size)
            m_field_masks = ops.tile_along_beam(m_field_masks, beam_size)
            m_value_masks = ops.tile_along_beam(m_value_masks, beam_size)
        if memory_inputs is not None:
            constant_seq_len = ops.tile_along_beam(constant_seq_len, beam_size)
            memory_inputs = ops.tile_along_beam(memory_inputs, beam_size)
        if encoder_ptr_value_ids is not None:
            encoder_ptr_value_ids = ops.tile_along_beam(encoder_ptr_value_ids, beam_size)
        seq_p_pointers = None
        ptr_context = None
        seq_text_ptr_weights = None
    elif model == SEQ2SEQ:
        # NOTE(review): encoder_hiddens/encoder_masks are not tiled along the beam
        # here. Presumably the caller or the decoder handles that; verify.
        seq_text_ptr_weights = None
    else:
        raise NotImplementedError

    # After the first step, pred_score is the [full_size, 1] score of each live hypothesis.
    pred_score = 0
    outputs, hiddens = None, (None, None)

    for i in range(num_steps):
        if i > 0:
            if model in [BRIDGE]:
                # Ids < decoder.vocab_size are vocabulary tokens; larger ids point
                # at memory position (id - decoder.vocab_size). Pointer ids are
                # replaced by the type id (table/field/value) of that memory slot.
                # [batch_size, 1]
                vocab_mask = (input < decoder.vocab_size).long()
                point_mask = 1 - vocab_mask
                memory_pos = (input - decoder.vocab_size) * point_mask
                memory_input = ops.batch_lookup(memory_inputs, memory_pos, vector_output=False)
                input_ = vocab_mask * input + point_mask * memory_input
                # Digit tokens from the vocabulary are treated as values. They are
                # embedded as value_id and do not count as vocabulary tokens in the
                # masking heuristics below.
                digit_mask = ((input == digit_0_id) |
                       (input == digit_1_id) |
                       (input == digit_2_id) |
                       (input == digit_3_id) |
                       (input == digit_4_id) |
                       (input == digit_5_id) |
                       (input == digit_6_id) |
                       (input == digit_7_id) |
                       (input == digit_8_id) |
                       (input == digit_9_id) |
                       (input == digit_10_id) |
                       (input == digit_11_id) |
                       (input == digit_12_id) |
                       (input == digit_s0_id) |
                       (input == digit_s1_id) |
                       (input == digit_s2_id) |
                       (input == digit_s3_id) |
                       (input == digit_s4_id) |
                       (input == digit_s5_id))
                vocab_mask[digit_mask] = 0
                input_[digit_mask] = decoder.vocab.value_id
                if db_scope is not None:
                    # [full_size, 3 (table, field, value)]
                    input_types = ops.long_var_cuda([decoder.vocab.table_id,
                                                     decoder.vocab.field_id,
                                                     decoder.vocab.value_id]).unsqueeze(0).expand([input_.size(0), 3])
                    # One-hot type of the previous token.
                    # [full_size, 4 (vocab, table, field, value)]
                    input_type = torch.cat([vocab_mask, (input_ == input_types).long()], dim=1)
                    # [full_size, max_num_tables], [full_size, max_num_tables, max_num_fields_per_table]
                    table_memory_pos, table_field_scopes = db_scope
                    # update vocab masks
                    # vocab_masks = torch.index_select(vocab_masks, 0, beam_offset)
                    # update memory masks
                    # The field mask is per hypothesis, so it follows the parent
                    # beams chosen in the previous step.
                    m_field_masks = torch.index_select(m_field_masks, 0, beam_offset)
                    # True where the previous token pointed at that table.
                    # [full_size, max_num_tables]
                    table_input_mask = (memory_pos == table_memory_pos)
                    if table_input_mask.max() > 0:
                        # [full_size, 1, max_num_fields_per_table]
                        db_scope_update_idx, _ = ops.batch_binary_lookup_3D(
                            table_field_scopes, table_input_mask, pad_value=0)
                        assert(db_scope_update_idx.size(1) == 1)
                        db_scope_update_idx.squeeze_(1)
                        db_scope_update_mask = (db_scope_update_idx > 0)
                        db_scope_update_idx = constant_seq_len.unsqueeze(1) + db_scope_update_idx
                        # db_scope_update: [full_size, memory_seq_len] binary mask in which the newly included table
                        # fields are set to 1 and the rest are set to 0
                        # assert (db_scope_update.max() <= 1)
                        # *
                        # Enable the first schema slot (presumably the '*' column),
                        # then the fields of the just-decoded table, and re-binarise.
                        m_field_masks.scatter_(index=constant_seq_len.unsqueeze(1),
                                               src=ops.int_ones_var_cuda([batch_size*beam_size, 1]), dim=1)
                        m_field_masks.scatter_add_(index=db_scope_update_idx, src=db_scope_update_mask.long(), dim=1)
                        m_field_masks = (m_field_masks > 0).long()
                    # Heuristics:
                    # - table/field only appear after SQL keywords
                    # - value only appear after SQL keywords or other value token
                    memory_masks = input_type[:, 0].unsqueeze(1) * m_table_masks + \
                                   input_type[:, 0].unsqueeze(1) * m_field_masks + \
                                   (input_type[:, 0] + input_type[:, 3]).unsqueeze(1) * m_value_masks
                    # print(input_type[0])
                    # print(memory_masks[0])
                    # import pdb
                    # pdb.set_trace()
            elif model in [SEQ2SEQ_PG, SEQ2SEQ]:
                input_ = decoder.get_input_feed(input)
            else:
                input_ = input
            input_embedded = decoder_embeddings(input_)
        else:
            if start_embedded is None:
                # Every hypothesis starts from the start token.
                input = ops.int_fill_var_cuda([full_size, 1], start_id)
                input_embedded = decoder_embeddings(input)
            else:
                raise NotImplementedError
        if model in [BRIDGE]:
            output, hidden, ptr_context = decoder(
                input_embedded,
                hidden,
                encoder_hiddens,
                encoder_masks,
                ptr_context,
                vocab_masks=vocab_masks,
                memory_masks=memory_masks,
                encoder_ptr_value_ids=encoder_ptr_value_ids,
                last_output=input)
        elif model in [SEQ2SEQ_PG]:
            output, hidden, ptr_context = decoder(
                input_embedded,
                hidden,
                encoder_hiddens,
                encoder_masks,
                ptr_context,
                encoder_ptr_value_ids=encoder_ptr_value_ids,
                last_output=input)
        elif model == SEQ2SEQ:
            output, hidden, text_ptr_weights = decoder(input_embedded, hidden, encoder_hiddens, encoder_masks)
        else:
            raise NotImplementedError

        # [full_size, vocab_size]
        # Here vocab_size is the width of the decoder output. For pointer models
        # it presumably exceeds decoder.vocab_size, since pointer ids are offset by it.
        output.squeeze_(1)
        vocab_size = output.size(1)

        # Only unfinished hypotheses grow in length.
        seq_len += (1 - seen_eos.float())
        n_len_norm_factor = torch.pow(5 + seq_len, alpha) / np.power(5 + 1, alpha)
        # [full_size, vocab_size]
        if i == 0:
            # All beams of an example start identical; keep only beam 0 so the
            # first top-k does not select duplicate hypotheses.
            raw_scores = \
                output + (ops.arange_cuda(beam_size).repeat(batch_size) > 0).float().unsqueeze(1) * (-ops.HUGE_INT)
        else:
            # Un-normalise the previous score, add this step's score (zero for
            # finished hypotheses) and re-normalise with the new length.
            raw_scores = (pred_score * len_norm_factor + output * (1 - seen_eos.float())) / n_len_norm_factor
            # Finished hypotheses may only be extended with eos.
            eos_mask = ops.ones_var_cuda([1, vocab_size])
            eos_mask[0, eos_id] = 0
            raw_scores += (seen_eos.float() * eos_mask) * (-ops.HUGE_INT)

        len_norm_factor = n_len_norm_factor
        # [batch_size, beam_size * vocab_size]
        raw_scores = raw_scores.view(batch_size, beam_size * vocab_size)
        # [batch_size, beam_size]
        log_pred_prob, pred_idx = torch.topk(raw_scores, beam_size, dim=1)
        # beam_offset[r]: full_size row of the parent hypothesis of new row r.
        # [full_size]
        beam_offset = (pred_idx // vocab_size + ops.arange_cuda(batch_size).unsqueeze(1) * beam_size).view(-1)
        # [full_size, 1]
        pred_idx = (pred_idx % vocab_size).view(full_size, 1)
        log_pred_prob = log_pred_prob.view(full_size, 1)

        # update search history and save output
        # [num_layers*num_directions, full_size, hidden_dim]
        hidden = offset_hidden(hidden, beam_offset)
        # [num_layers*num_directions, full_size, seq_len, hidden_dim]
        if decoder.return_hiddens:
            hiddens = (
                update_beam_search_history(hiddens[0], hidden[0].unsqueeze(2), beam_offset, 1, 2),
                update_beam_search_history(hiddens[1], hidden[1].unsqueeze(2), beam_offset, 1, 2)
            )
        # Per-hypothesis state follows the selected parents. At step 0 these
        # tensors are identical across rows, so no reordering is needed.
        if outputs is not None:
            seq_len = torch.index_select(seq_len, 0, beam_offset)
            len_norm_factor = torch.index_select(len_norm_factor, 0, beam_offset)
            seen_eos = torch.index_select(seen_eos, 0, beam_offset)
        seen_eos = seen_eos | (pred_idx == eos_id)
        outputs = update_beam_search_history(outputs, pred_idx, beam_offset, 0, 1)
        pred_score = log_pred_prob

        input = pred_idx

        # save attention weights for interpretation and sanity checking
        if model in [SEQ2SEQ_PG, BRIDGE]:
            ptr_context = (torch.index_select(ptr_context[0], 0, beam_offset),
                           torch.index_select(ptr_context[1], 0, beam_offset))
            seq_text_ptr_weights = update_beam_search_history(
                seq_text_ptr_weights, ptr_context[1], beam_offset, 0, 2)
            seq_p_pointers = update_beam_search_history(
                seq_p_pointers, ptr_context[0].squeeze(2), beam_offset, 0, 1)
        elif model == SEQ2SEQ:
            seq_text_ptr_weights = update_beam_search_history(
                seq_text_ptr_weights, text_ptr_weights, beam_offset, 0, 2, offset_state=True)
        else:
            raise NotImplementedError

    if model in [SEQ2SEQ_PG, BRIDGE]:
        output_obj = outputs, pred_score, seq_p_pointers, seq_text_ptr_weights, seq_len
    elif model in [SEQ2SEQ]:
        output_obj = outputs, pred_score, seq_text_ptr_weights, seq_len
    else:
        raise NotImplementedError

    if decoder.return_hiddens:
        hidden_dim = hiddens[0].size(3)
        return output_obj, (hiddens[0].view(-1, batch_size, beam_size, num_steps, hidden_dim),
                            hiddens[1].view(-1, batch_size, beam_size, num_steps, hidden_dim))
    # elif return_final_hidden:
    #     hidden_dim = hidden[0].size(3)
    #     return output_obj, (hidden[0].view(-1, batch_size, beam_size, hidden_dim),
    #                         hidden[1].view(-1, batch_size, beam_size, hidden_dim))
    else:
        return output_obj
예제 #4
0
def ensemble_beam_search(sps, encoder_ptr_input_ids, encoder_ptr_value_ids,
                         text_masks, schema_masks, feature_ids, graphs,
                         transformer_output_value_masks, schema_memory_masks):
    with torch.no_grad():
        inputs, input_masks = encoder_ptr_input_ids
        if sps[0].pretrained_transformer:
            segment_ids, position_ids = sps[0].get_segment_and_position_ids(
                inputs)
        inputs_embedded = []
        encoder_hiddens, hidden = [], []
        for i, sp in enumerate(sps):
            if sp.pretrained_transformer:
                inputs_embedded_, _ = sp.encoder_embeddings(
                    inputs,
                    input_masks,
                    segments=segment_ids,
                    position_ids=position_ids)
            else:
                inputs_embedded_ = sp.encoder_embeddings(inputs)
            encoder_hiddens_, encoder_hidden_masks, constant_hidden_masks, schema_hidden_masks, hidden_ = \
                sp.encoder(inputs_embedded_,
                           input_masks,
                           text_masks,
                           schema_masks,
                           feature_ids,
                           transformer_output_value_masks)
            inputs_embedded.append(inputs_embedded_)
            encoder_hiddens.append(encoder_hiddens_)
            hidden.append(hidden_)

        table_masks, _ = feature_ids[3]
        table_pos, _ = feature_ids[4]
        if table_pos is not None:
            table_field_scope, _ = feature_ids[5]
            db_scope = (table_pos, table_field_scope)
        else:
            db_scope = None

        alpha = sps[0].bs_alpha
        model = sps[0].model_id
        num_models = len(sps)
        num_steps = sps[0].max_out_seq_len
        beam_size = sps[0].beam_size
        batch_size = encoder_hiddens[0].size(0)
        full_size = batch_size * beam_size

        start_id = sps[0].decoder.vocab.start_id
        eos_id = sps[0].decoder.vocab.eos_id
        table_id = sps[0].decoder.vocab.table_id
        field_id = sps[0].decoder.vocab.field_id
        value_id = sps[0].decoder.vocab.value_id
        vocab_size = sps[0].decoder.vocab_size
        seen_eos = ops.byte_zeros_var_cuda([full_size, 1])
        seq_len = 0
        start_embedded = None
        if type(hidden[-1]) is tuple:
            assert (len(hidden[-1]) == 2)
            hidden = [(ops.tile_along_beam(x, beam_size, dim=1),
                       ops.tile_along_beam(y, beam_size, dim=1))
                      for x, y in hidden]
        elif type(hidden[-1]) is not None:
            hidden = [ops.tile_along_beam(x, beam_size, dim=1) for x in hidden]

        constant_seq_len = constant_hidden_masks.size(
            1) - constant_hidden_masks.sum(dim=1)
        vocab_masks, memory_masks = None, None
        if model in [BRIDGE]:
            schema_seq_len = schema_hidden_masks.size(
                1) - schema_hidden_masks.sum(dim=1)
            memory_inputs, m_table_masks, m_field_masks, m_value_masks = \
                compute_memory_inputs(batch_size, constant_seq_len, schema_seq_len, table_masks,
                                      table_id, field_id, value_id)
            if db_scope is not None:
                # vocab_mask = ops.int_ones_var_cuda(decoder.vocab.size)
                # vocab_mask[decoder.vocab.to_idx('from')] = 1
                # vocab_mask[decoder.vocab.to_idx('(')] = 1
                # v_clause_mask, v_op_mask, v_join_mask, v_others_mask = get_vocab_cat_masks()
                memory_masks = ops.int_zeros_var_cuda(
                    [batch_size, memory_inputs.size(1)])

                table_pos, table_field_scopes = db_scope
                table_memory_pos = constant_seq_len.unsqueeze(1) * (
                    table_pos > 0).long() + table_pos
                table_memory_pos = ops.tile_along_beam(table_memory_pos,
                                                       beam_size)
                table_field_scopes = ops.tile_along_beam(
                    table_field_scopes, beam_size)
                db_scope = (table_memory_pos, table_field_scopes)
        else:
            memory_inputs = None

        if model in [SEQ2SEQ_PG, BRIDGE]:
            encoder_hiddens = [
                ops.tile_along_beam(x, beam_size) for x in encoder_hiddens
            ]
            encoder_masks = ops.tile_along_beam(encoder_hidden_masks,
                                                beam_size)
            if table_masks is not None:
                table_masks = ops.tile_along_beam(table_masks, beam_size)
            if memory_masks is not None:
                # assert(vocab_mask is not None)
                # vocab_masks = ops.tile_along_beam(vocab_mask.unsqueeze(0), batch_size * beam_size)
                # v_clause_masks = ops.tile_along_beam(v_clause_mask.unsqueeze(0), batch_size * beam_size)
                # v_op_masks = ops.tile_along_beam(v_op_mask.unsqueeze(0), batch_size * beam_size)
                # v_join_masks = ops.tile_along_beam(v_join_mask.unsqueeze(0), batch_size * beam_size)
                # v_others_masks = ops.tile_along_beam(v_others_mask.unsqueeze(0), batch_size * beam_size)
                memory_masks = ops.tile_along_beam(memory_masks, beam_size)
                m_table_masks = ops.tile_along_beam(m_table_masks, beam_size)
                m_field_masks = ops.tile_along_beam(m_field_masks, beam_size)
                m_value_masks = ops.tile_along_beam(m_value_masks, beam_size)
            if memory_inputs is not None:
                constant_seq_len = ops.tile_along_beam(constant_seq_len,
                                                       beam_size)
                memory_inputs = ops.tile_along_beam(memory_inputs, beam_size)
            if encoder_ptr_value_ids is not None:
                encoder_ptr_value_ids = ops.tile_along_beam(
                    encoder_ptr_value_ids, beam_size)
            seq_p_pointers = [None for _ in range(num_models)]
            ptr_context = [None for _ in range(num_models)]
            seq_text_ptr_weights = [None for _ in range(num_models)]
        elif model == SEQ2SEQ:
            seq_text_ptr_weights = [None for _ in range(num_models)]
        else:
            raise NotImplementedError

        pred_score = 0
        outputs = None
        hiddens = [(None, None) for _ in range(num_models)]

        for step_id in range(num_steps):
            if step_id > 0:
                if model in [BRIDGE]:
                    # [batch_size, 1]
                    vocab_mask = (input < vocab_size).long()
                    point_mask = 1 - vocab_mask
                    memory_pos = (input - vocab_size) * point_mask
                    memory_input = ops.batch_lookup(memory_inputs,
                                                    memory_pos,
                                                    vector_output=False)
                    input_ = vocab_mask * input + point_mask * memory_input
                    if db_scope is not None:
                        # [full_size, 3 (table, field, value)]
                        input_types = ops.long_var_cuda([
                            table_id, field_id, value_id
                        ]).unsqueeze(0).expand([input_.size(0), 3])
                        # [full_size, 4 (vocab, table, field, value)]
                        input_type = torch.cat(
                            [vocab_mask, (input_ == input_types).long()],
                            dim=1)
                        # [full_size, max_num_tables], [full_size, max_num_tables, max_num_fields_per_table]
                        table_memory_pos, table_field_scopes = db_scope
                        # update vocab masks
                        # vocab_masks = torch.index_select(vocab_masks, 0, beam_offset)
                        # update memory masks
                        m_field_masks = torch.index_select(
                            m_field_masks, 0, beam_offset)
                        # [full_size, max_num_tables]
                        table_input_mask = (memory_pos == table_memory_pos)
                        if table_input_mask.max() > 0:
                            # [full_size, 1, max_num_fields_per_table]
                            db_scope_update_idx, _ = ops.batch_binary_lookup_3D(
                                table_field_scopes,
                                table_input_mask,
                                pad_value=0)
                            assert (db_scope_update_idx.size(1) == 1)
                            db_scope_update_idx.squeeze_(1)
                            db_scope_update_mask = (db_scope_update_idx > 0)
                            db_scope_update_idx = constant_seq_len.unsqueeze(
                                1) + db_scope_update_idx
                            # db_scope_update: [full_size, memory_seq_len] binary mask in which the newly included table
                            # fields are set to 1 and the rest are set to 0
                            # assert (db_scope_update.max() <= 1)
                            # *
                            m_field_masks.scatter_(
                                index=constant_seq_len.unsqueeze(1),
                                src=ops.int_ones_var_cuda(
                                    [batch_size * beam_size, 1]),
                                dim=1)
                            m_field_masks.scatter_add_(
                                index=db_scope_update_idx,
                                src=db_scope_update_mask.long(),
                                dim=1)
                            m_field_masks = (m_field_masks > 0).long()
                        # Heuristics:
                        # - table/field only appear after SQL keywords
                        # - value only appear after SQL keywords or other value token
                        memory_masks = input_type[:, 0].unsqueeze(1) * m_table_masks + \
                                       input_type[:, 0].unsqueeze(1) * m_field_masks + \
                                       (input_type[:, 0] + input_type[:, 3]).unsqueeze(1) * m_value_masks
                        # print(input_type[0])
                        # print(memory_masks[0])
                        # import pdb
                        # pdb.set_trace()
                elif model in [SEQ2SEQ_PG, SEQ2SEQ]:
                    input_ = sps[0].decoder.get_input_feed(input)
                else:
                    input_ = input
                input_embedded = [sp.decoder_embeddings(input_) for sp in sps]
            else:
                if start_embedded is None:
                    input = ops.int_fill_var_cuda([full_size, 1], start_id)
                    input_embedded = [
                        sp.decoder_embeddings(input) for sp in sps
                    ]
                else:
                    raise NotImplementedError
            # print(step_id)
            # import pdb
            # pdb.set_trace()
            output, hidden_local, ptr_context_local, text_ptr_weights = [], [], [], []
            for i, sp in enumerate(sps):
                if model in [BRIDGE]:
                    output_, hidden_, ptr_context_ = sp.decoder(
                        input_embedded[i],
                        hidden[i],
                        encoder_hiddens[i],
                        encoder_masks,
                        ptr_context[i],
                        vocab_masks=vocab_masks,
                        memory_masks=memory_masks,
                        encoder_ptr_value_ids=encoder_ptr_value_ids,
                        last_output=input)
                    output.append(output_)
                    hidden_local.append(hidden_)
                    ptr_context_local.append(ptr_context_)
                elif model in [SEQ2SEQ_PG]:
                    output_, hidden_, ptr_context_ = sp.decoder(
                        input_embedded[i],
                        hidden[i],
                        encoder_hiddens[i],
                        encoder_masks,
                        ptr_context[i],
                        encoder_ptr_value_ids=encoder_ptr_value_ids,
                        last_output=input)
                    output.append(output_)
                    hidden_local.append(hidden_)
                    ptr_context_local.append(ptr_context_)
                elif model == SEQ2SEQ:
                    output_, hidden_, text_ptr_weights_ = sp.decoder(
                        input_embedded[i], hidden[i], encoder_hiddens[i],
                        encoder_masks)
                    output.append(output_)
                    hidden_local.append(hidden_)
                    text_ptr_weights.append(text_ptr_weights_)
                else:
                    raise NotImplementedError

            # [full_size, vocab_size]
            # Average the probability of the ensemble
            output = torch.mean(torch.stack(output), dim=0)
            output.squeeze_(1)
            out_vocab_size = output.size(1)

            seq_len += (1 - seen_eos.float())
            n_len_norm_factor = torch.pow(5 + seq_len, alpha) / np.power(
                5 + 1, alpha)
            # [full_size, vocab_size]
            if step_id == 0:
                raw_scores = \
                    output + (ops.arange_cuda(beam_size).repeat(batch_size) > 0).float().unsqueeze(1) * (-ops.HUGE_INT)
            else:
                raw_scores = (pred_score * len_norm_factor + output *
                              (1 - seen_eos.float())) / n_len_norm_factor
                eos_mask = ops.ones_var_cuda([1, out_vocab_size])
                eos_mask[0, eos_id] = 0
                raw_scores += (seen_eos.float() * eos_mask) * (-ops.HUGE_INT)

            len_norm_factor = n_len_norm_factor
            # [batch_size, beam_size * vocab_size]
            raw_scores = raw_scores.view(batch_size,
                                         beam_size * out_vocab_size)
            # [batch_size, beam_size]
            log_pred_prob, pred_idx = torch.topk(raw_scores, beam_size, dim=1)
            # [full_size]
            beam_offset = (
                pred_idx // out_vocab_size +
                ops.arange_cuda(batch_size).unsqueeze(1) * beam_size).view(-1)
            # [full_size, 1]
            pred_idx = (pred_idx % out_vocab_size).view(full_size, 1)
            log_pred_prob = log_pred_prob.view(full_size, 1)

            # update search history and save output
            # [num_layers*num_directions, full_size, hidden_dim]
            hidden = [offset_hidden(x, beam_offset) for x in hidden_local]
            # [num_layers*num_directions, full_size, seq_len, hidden_dim]
            if sps[0].decoder.return_hiddens:
                hiddens = [
                    (update_beam_search_history(hiddens[i][0],
                                                hidden[i][0].unsqueeze(2),
                                                beam_offset, 1, 2),
                     update_beam_search_history(hiddens[i][1],
                                                hidden[i][1].unsqueeze(2),
                                                beam_offset, 1, 2))
                    for i in range(num_models)
                ]
            if outputs is not None:
                seq_len = torch.index_select(seq_len, 0, beam_offset)
                len_norm_factor = torch.index_select(len_norm_factor, 0,
                                                     beam_offset)
                seen_eos = torch.index_select(seen_eos, 0, beam_offset)
            seen_eos = seen_eos | (pred_idx == eos_id)
            outputs = update_beam_search_history(outputs, pred_idx,
                                                 beam_offset, 0, 1)
            pred_score = log_pred_prob

            input = pred_idx

            # save attention weights for interpretation and sanity checking
            if model in [SEQ2SEQ_PG, BRIDGE]:
                ptr_context = [(torch.index_select(ptr_context_local[i][0], 0,
                                                   beam_offset),
                                torch.index_select(ptr_context_local[i][1], 0,
                                                   beam_offset))
                               for i in range(num_models)]
                seq_text_ptr_weights = [
                    update_beam_search_history(seq_text_ptr_weights[i],
                                               ptr_context[i][1], beam_offset,
                                               0, 2) for i in range(num_models)
                ]
                seq_p_pointers = [
                    update_beam_search_history(seq_p_pointers[i],
                                               ptr_context[i][0].squeeze(2),
                                               beam_offset, 0, 1)
                    for i in range(num_models)
                ]
            elif model == SEQ2SEQ:
                seq_text_ptr_weights = [
                    update_beam_search_history(seq_text_ptr_weights[i],
                                               text_ptr_weights[i],
                                               beam_offset,
                                               0,
                                               2,
                                               offset_state=True)
                    for i in range(num_models)
                ]
            else:
                raise NotImplementedError

        if model in [SEQ2SEQ_PG, BRIDGE]:
            output_obj = outputs, pred_score, seq_p_pointers[
                0], seq_text_ptr_weights[0], seq_len
        elif model in [SEQ2SEQ]:
            output_obj = outputs, pred_score, seq_text_ptr_weights[0], seq_len
        else:
            raise NotImplementedError

        if sps[0].decoder.return_hiddens:
            hidden_dim = hiddens[0].size(3)
            return output_obj, (hiddens[0].view(-1, batch_size, beam_size,
                                                num_steps, hidden_dim),
                                hiddens[1].view(-1, batch_size, beam_size,
                                                num_steps, hidden_dim))
        # elif return_final_hidden:
        #     hidden_dim = hidden[0].size(3)
        #     return output_obj, (hidden[0].view(-1, batch_size, beam_size, hidden_dim),
        #                         hidden[1].view(-1, batch_size, beam_size, hidden_dim))
        else:
            return output_obj