Example 1
    def predict(self, k: int, query: torch.FloatTensor,
                memory_bank: torch.FloatTensor,
                memory_labels: torch.LongTensor):

        C = self.num_classes
        T = self.temperature
        B, _ = query.size()

        # Cosine similarity via dot products (assumes query and memory_bank
        # hold L2-normalized features)
        sim_matrix = torch.einsum(
            'bf,fm->bm', [query, memory_bank])  # (b, f) @ (f, M) -> (b, M)
        sim_weight, sim_indices = sim_matrix.sort(
            dim=1, descending=True)  # (b, M), (b, M)
        # keep the k nearest neighbours
        sim_weight, sim_indices = sim_weight[:, :k], sim_indices[:, :k]  # (b, k), (b, k)
        sim_weight = (sim_weight / T).exp()  # temperature-scaled weights, (b, k)
        sim_labels = torch.gather(
            memory_labels.expand(B, -1),  # -> (b, M)
            dim=1,
            index=sim_indices)  # output matches index shape: (b, k)

        one_hot = torch.zeros(B * k, C, device=sim_labels.device)  # (bk, C)
        one_hot.scatter_(dim=-1, index=sim_labels.view(-1, 1),
                         value=1)  # (bk, C) <- scatter <- (bk, 1)
        pred = one_hot.view(B, k, C) * sim_weight.unsqueeze(
            dim=-1)  # (b, k, C) * (b, k, 1)
        pred = pred.sum(dim=1)  # (b, C)

        return pred.argsort(
            dim=-1, descending=True
        )  # (b, C); first column gives label of highest confidence
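
Note that the einsum above computes raw dot products, so the result is a cosine similarity only when query and memory_bank hold L2-normalized features. A minimal usage sketch (the KNNClassifier wrapper and all sizes/hyper-parameters below are hypothetical stand-ins for whatever class owns predict):

import torch
import torch.nn.functional as F

class KNNClassifier:
    """Hypothetical host class; `predict` above is assumed to be its method."""
    def __init__(self, num_classes: int, temperature: float):
        self.num_classes = num_classes
        self.temperature = temperature
    # ... predict(self, k, query, memory_bank, memory_labels) as defined above ...

knn = KNNClassifier(num_classes=10, temperature=0.07)
M, feature_dim = 4096, 128
query = F.normalize(torch.randn(8, feature_dim), dim=1)            # (b, f)
memory_bank = F.normalize(torch.randn(M, feature_dim), dim=1).t()  # (f, M)
memory_labels = torch.randint(0, 10, (M,))                         # (M,)
ranked = knn.predict(k=200, query=query, memory_bank=memory_bank,
                     memory_labels=memory_labels)                  # (b, C)
top1 = ranked[:, 0]  # most confident class per query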
Example 2
def _select_lmc_coefficients(lmc_coefficients: torch.Tensor,
                             indices: torch.LongTensor) -> torch.Tensor:
    """
    Given a list of indices for ... x N datapoints,
      select the row from lmc_coefficients that corresponds to each datapoint.

    lmc_coefficients: torch.Tensor ... x num_latents x ... x num_tasks
    indices: torch.LongTensor ... x N
    """
    batch_shape = _mul_broadcast_shape(lmc_coefficients.shape[:-1],
                                       indices.shape[:-1])

    # We will use the left_interp helper to do the indexing
    lmc_coefficients = lmc_coefficients.expand(
        *batch_shape, lmc_coefficients.shape[-1])[..., None]
    indices = indices.expand(*batch_shape, indices.shape[-1])[..., None]
    res = left_interp(
        indices,
        torch.ones(indices.shape, dtype=torch.long, device=indices.device),
        lmc_coefficients,
    ).squeeze(-1)
    return res
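
For intuition, here is a minimal pure-PyTorch sketch of what this selection computes, under my reading of left_interp with all-ones interpolation values (torch.broadcast_shapes stands in for the library-internal _mul_broadcast_shape; select_rows_with_gather is a hypothetical name):

import torch

def select_rows_with_gather(lmc_coefficients: torch.Tensor,
                            indices: torch.LongTensor) -> torch.Tensor:
    # lmc_coefficients: (..., num_tasks); indices: (..., N), values in [0, num_tasks)
    batch_shape = torch.broadcast_shapes(lmc_coefficients.shape[:-1],
                                         indices.shape[:-1])
    lmc_coefficients = lmc_coefficients.expand(*batch_shape,
                                               lmc_coefficients.shape[-1])
    indices = indices.expand(*batch_shape, indices.shape[-1])
    return lmc_coefficients.gather(-1, indices)  # (..., N)

coeffs = torch.randn(3, 5)           # e.g. num_latents=3, num_tasks=5
idx = torch.tensor([0, 4, 2])        # N=3 datapoints
print(select_rows_with_gather(coeffs, idx).shape)  # torch.Size([3, 3])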
Example 3
    def forward(
            self,  # type: ignore
            question: Dict[str, torch.LongTensor],
            passage: Dict[str, torch.LongTensor],
            span_start: torch.LongTensor = None,
            span_end: torch.LongTensor = None,
            spans=None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        # pylint: disable=arguments-differ
        # Shape: (batch_size, num_passage, passage_length, embedding_dim)
        embedded_passage = self._text_field_embedder(passage)
        passage_mask = util.get_text_field_mask(passage, 1).float()

        # get some parameters
        cuda_device = embedded_passage.get_device()
        batch_size, num_passage, passage_length, embedding_dim = embedded_passage.size()

        # During training, randomly sample 2 of the 4 candidate passages for
        # each example; passage 0 is favoured (probability 0.4 vs. 0.2), and
        # the loss below labels any sampled passage other than 0 as "no answer".
        if self.training:
            categorical_probs = Variable(
                torch.Tensor([0.4, 0.2, 0.2, 0.2]).unsqueeze(0).expand(
                    batch_size, 4)).cuda(cuda_device)
            num_passage //= 2  # integer division; torch.multinomial needs an int
            # sample without replacement (multinomial's default)
            indices = torch.multinomial(categorical_probs, num_passage)
            # The sampled indices are expanded to match the trailing dims of
            # the tensor being gathered.
            # Shape: (batch_size, num_passage, passage_length, embedding_dim)
            embedded_passage = torch.gather(
                embedded_passage, 1,
                indices.unsqueeze(-1).unsqueeze(-1).expand(
                    batch_size, num_passage, passage_length, embedding_dim))
            # Shape: (batch_size, num_passage, passage_length)
            passage_mask = torch.gather(
                passage_mask, 1,
                indices.unsqueeze(-1).expand(batch_size, num_passage,
                                             passage_length))

        # Shape: (batch_size*num_passage, passage_length, embedding_dim)
        embedded_passage = embedded_passage.view(-1, passage_length,
                                                 embedding_dim)
        embedded_passage = self._highway_layer(embedded_passage)
        # Shape: (batch_size*num_passage, passage_length)
        passage_mask = passage_mask.view(-1, passage_length)
        # Shape: (batch_size, question_length, embedding_dim)
        embedded_question = self._highway_layer(
            self._text_field_embedder(question))
        question_length = embedded_question.size(1)
        # Shape: (batch_size*num_passage, question_length, embedding_dim)
        # Repeat each question once per passage, keeping the same batch-major
        # order as the flattened passages above.
        embedded_question = embedded_question.unsqueeze(1).expand(
            -1, num_passage, -1, -1).contiguous().view(-1, question_length,
                                                       embedding_dim)
        # Shape: (batch_size, question_length)
        question_mask = util.get_text_field_mask(question).float()
        # Shape: (batch_size*num_passage, question_length)
        question_mask = question_mask.unsqueeze(1).expand(
            -1, num_passage, -1).contiguous().view(-1, question_length)

        question_lstm_mask = question_mask if self._mask_lstms else None
        passage_lstm_mask = passage_mask if self._mask_lstms else None

        encoded_question = self._dropout(
            self._phrase_layer(embedded_question, question_lstm_mask))
        encoded_passage = self._dropout(
            self._phrase_layer(embedded_passage, passage_lstm_mask))
        encoding_dim = encoded_question.size(-1)

        # Shape: (batch_size*num_passage, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(
            encoded_passage, encoded_question)
        # Shape: (batch_size*num_passage, passage_length, question_length)
        passage_question_attention = util.last_dim_softmax(
            passage_question_similarity, question_mask)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(
            encoded_question, passage_question_attention)

        # We replace masked values with something really negative here, so they don't affect the
        # max below.
        masked_similarity = util.replace_masked_values(
            passage_question_similarity, question_mask.unsqueeze(1), -1e7)
        # Shape: (batch_size*num_passage, passage_length)
        question_passage_similarity = masked_similarity.max(dim=-1)[0]
        # Shape: (batch_size*num_passage, passage_length)
        question_passage_attention = util.masked_softmax(
            question_passage_similarity, passage_mask)
        # Shape: (batch_size*num_passage, encoding_dim)
        question_passage_vector = util.weighted_sum(
            encoded_passage, question_passage_attention)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        tiled_question_passage_vector = question_passage_vector.unsqueeze(
            1).expand(batch_size * num_passage, passage_length, encoding_dim)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 4)
        final_merged_passage = torch.cat([
            encoded_passage, passage_question_vectors,
            encoded_passage * passage_question_vectors,
            encoded_passage * tiled_question_passage_vector
        ], dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        question_attended_passage = relu(
            self._linear_layer(final_merged_passage))

        # attach residual self-attention layer
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        residual_passage = self._dropout(
            self._residual_encoder(self._dropout(question_attended_passage),
                                   passage_lstm_mask))
        # Mask for self-attention: valid (i, j) token pairs, with the diagonal
        # zeroed out so a token cannot attend to itself.
        mask = passage_mask.view(
            batch_size * num_passage, passage_length, 1) * passage_mask.view(
                batch_size * num_passage, 1, passage_length)
        self_mask = Variable(
            torch.eye(passage_length,
                      passage_length).cuda(cuda_device)).view(
                          1, passage_length, passage_length)
        mask = mask * (1 - self_mask)
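        # Trilinear self-similarity, computed via broadcasting:
        #   s[i, j] = x_i . w_x + x_j . w_y + (x_i * w_xy) . x_j
        # unsqueeze(2) tiles the w_x term over columns, unsqueeze(1) tiles the
        # w_y term over rows, and the bmm supplies the bilinear term.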
        # Shape: (batch_size*num_passage, passage_length, passage_length)
        x_similarity = torch.matmul(residual_passage, self._w_x).unsqueeze(2)
        y_similarity = torch.matmul(residual_passage, self._w_y).unsqueeze(1)
        dot_similarity = torch.bmm(residual_passage * self._w_xy,
                                   residual_passage.transpose(1, 2))
        passage_self_similarity = dot_similarity + x_similarity + y_similarity
        # Shape: (batch_size*num_passage, passage_length, passage_length)
        passage_self_attention = util.last_dim_softmax(passage_self_similarity,
                                                       mask)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        passage_vectors = util.weighted_sum(residual_passage,
                                            passage_self_attention)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 3)
        merged_passage = torch.cat([
            residual_passage, passage_vectors,
            residual_passage * passage_vectors
        ], dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        self_attended_passage = relu(
            self._residual_linear_layer(merged_passage))

        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        mixed_passage = question_attended_passage + self_attended_passage

        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        encoded_span_start = self._dropout(
            self._span_start_encoder(mixed_passage, passage_lstm_mask))
        span_start_logits = self._span_start_predictor(
            encoded_span_start).squeeze(-1)
        span_start_probs = util.masked_softmax(span_start_logits, passage_mask)

        # Shape: (batch_size*num_passage, passage_length, encoding_dim * 2)
        concatenated_passage = torch.cat([mixed_passage, encoded_span_start],
                                         dim=-1)
        # Shape: (batch_size*num_passage, passage_length, encoding_dim)
        encoded_span_end = self._dropout(
            self._span_end_encoder(concatenated_passage, passage_lstm_mask))
        span_end_logits = self._span_end_predictor(encoded_span_end).squeeze(-1)
        span_end_probs = util.masked_softmax(span_end_logits, passage_mask)
        # Shape: (batch_size*num_passage, passage_length)
        span_start_logits = util.replace_masked_values(span_start_logits,
                                                       passage_mask, -1e7)
        span_end_logits = util.replace_masked_values(span_end_logits,
                                                     passage_mask, -1e7)

        # no answer option
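        # The start (v_1), end (v_2), and no-answer (v_3) summary vectors feed
        # a feed-forward net whose scalar z_score is appended below to the
        # flattened span logits, so "no answer" competes with every
        # (start, end) span inside a single softmax.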
        v_1 = util.weighted_sum(encoded_span_start, span_start_probs)
        v_2 = util.weighted_sum(encoded_span_end, span_end_probs)
        no_span_logits = self._no_answer_predictor(
            self_attended_passage).squeeze(-1)
        no_span_probs = util.masked_softmax(no_span_logits, passage_mask)
        v_3 = util.weighted_sum(self_attended_passage, no_span_probs)
        # Shape: (batch_size*num_passage, 1)
        z_score = self._feed_forward(torch.cat([v_1, v_2, v_3], dim=-1))
        # Shape: (batch_size*num_passage, passage_length, passage_length)
        # Element (start, end) holds span_start_logits[start] + span_end_logits[end],
        # so after flattening, index start * passage_length + end addresses a
        # span, matching the span_target encoding in the loss below.
        span_start_logits_tiled = span_start_logits.unsqueeze(-1).expand(
            -1, passage_length, passage_length)
        span_end_logits_tiled = span_end_logits.unsqueeze(1).expand(
            -1, passage_length, passage_length)
        # Shape: (batch_size*num_passage, passage_length**2)
        span_logits = (span_start_logits_tiled + span_end_logits_tiled).view(
            -1, passage_length**2)
        # Shape: (batch_size*num_passage, passage_length**2)
        answer_mask = torch.bmm(passage_mask.unsqueeze(-1),
                                passage_mask.unsqueeze(1)).view(
                                    -1, passage_length**2)
        # Shape: (batch_size*num_passage, 1)
        no_answer_mask = Variable(
            torch.ones(batch_size * num_passage).unsqueeze(-1)).cuda(cuda_device)
        # Shape: (batch_size*num_passage, passage_length**2 + 1)
        combined_mask = torch.cat([answer_mask, no_answer_mask], dim=1)
        # Shape: (batch_size*num_passage, passage_length**2 + 1)
        all_logits = torch.cat([span_logits, z_score], dim=-1)

        output_dict = {
            "passage_question_attention": passage_question_attention,
            "span_start_logits": span_start_logits,
            "span_start_probs": span_start_probs,
            "span_end_logits": span_end_logits,
            "span_end_probs": span_end_probs,
        }

        # compute loss
        if span_start is not None:
            if self.training:
                # Shape: (batch_size*num_passage)
                span_start = span_start.expand(
                    -1, num_passage).contiguous().view(-1)
                span_end = span_end.expand(
                    -1, num_passage).contiguous().view(-1)
                # Shape: (batch_size*num_passage)
                indices = indices.view(-1)
                # Shape: (batch_size*num_passage)
                span_target = span_start * passage_length + span_end
                # sampled passages other than passage 0 (the answer-bearing
                # one) get the extra no-answer label at index passage_length**2
                span_target[indices != 0] = passage_length**2
                loss = nll_loss(
                    util.masked_log_softmax(all_logits, combined_mask),
                    span_target)
            else:  # we do not compute the loss when validating
                loss = 0
            output_dict["loss"] = loss

        # Compute the EM and F1 on SQuAD and add the tokenized input to the output.
        if not self.training and metadata is not None:
            # Shape: (batch_size*num_passage, 3)
            best_span = self.get_best_span(span_start_logits, span_end_logits)
            # Shape: (batch_size, num_passage, 3)
            best_span = best_span.view(batch_size, num_passage, 3)

            output_dict['best_span_str'] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]['question_tokens'])
                passage_tokens.append(metadata[i]['passage_tokens'])
                passage_str = metadata[i]['original_passage']
                offsets = metadata[i]['token_offsets']
                all_passages = metadata[i]['all_passages']
                passage_offsets = metadata[i]['passage_offsetss']
                # pick the passage whose best span has the highest score
                _, max_id = torch.max(best_span[i, :, 2], dim=0)
                predicted_span = tuple(best_span[i, max_id].data.cpu().numpy())
                start_offset = passage_offsets[max_id][int(predicted_span[0])][0]
                end_offset = passage_offsets[max_id][int(predicted_span[1])][1]
                best_span_string = all_passages[max_id][start_offset:end_offset]
                output_dict['best_span_str'].append(best_span_string)
                answer_texts = metadata[i].get('answer_texts', [])
                if answer_texts:
                    self._squad_metrics(best_span_string, answer_texts)
            output_dict['question_tokens'] = question_tokens
            output_dict['passage_tokens'] = passage_tokens
        return output_dict
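
The span scoring above relies on a flattening trick worth isolating: broadcasting start logits over rows and end logits over columns builds all (start, end) pair scores at once, and after flattening, index start * passage_length + end addresses a given span, which is exactly how span_target is encoded in the loss. A standalone toy sketch (arbitrary sizes):

import torch

L = 5                                                        # passage_length
start_logits = torch.randn(2, L)                             # (batch, L)
end_logits = torch.randn(2, L)
# pair[b, s, e] = start_logits[b, s] + end_logits[b, e]
pair = start_logits.unsqueeze(-1) + end_logits.unsqueeze(1)  # (batch, L, L)
flat = pair.view(-1, L * L)                                  # span index = s * L + e
s, e = 1, 3
assert torch.allclose(flat[:, s * L + e],
                      start_logits[:, s] + end_logits[:, e])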
Example 4
    def forward(self,
                question: Dict[str, torch.LongTensor],
                choice1_indexes: List[int] = None,
                choice2_indexes: List[int] = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:

        self._debug -= 1
        input_ids = question['bert']

        # input_ids.size() == (batch_size, num_pairs, max_sentence_length)
        batch_size, num_pairs, _ = input_ids.size()
        question_mask = (input_ids != 0).long()

        if self._train_comparison_layer:
            assert num_pairs == self._num_choices * (self._num_choices - 1)

        # Segment ids
        real_segment_ids = question['bert-type-ids'].clone()
        # Change the last 'SEP' to belong to the second answer (for symmetry)
        last_seps = (real_segment_ids.roll(-1) == 2) & (real_segment_ids == 1)
        real_segment_ids[last_seps] = 2
        # Update segment ids so that they are '1' for answers and '0' for the question
        real_segment_ids = (real_segment_ids == 0) | (real_segment_ids == 2)
        real_segment_ids = real_segment_ids.long()

        # TODO: How to extract last token pooled output if batch size != 1
        assert batch_size == 1

        # Run model
        encoded_layers, first_vectors_pooled_output = self._bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(real_segment_ids),
            attention_mask=util.combine_initial_dims(question_mask),
            output_all_encoded_layers=self._all_layers)

        if self._use_comparative_bert:
            last_vectors_pooled_output = self._extract_last_token_pooled_output(encoded_layers, question_mask)
        else:
            last_vectors_pooled_output = None
        if self._all_layers:
            mixed_layer = self._scalar_mix(encoded_layers, question_mask)
            first_vectors_pooled_output = self._bert_model.pooler(mixed_layer)

        # Apply dropout
        first_vectors_pooled_output = self._dropout(first_vectors_pooled_output)
        if self._use_comparative_bert:
            last_vectors_pooled_output = self._dropout(last_vectors_pooled_output)

        # Classify
        if not self._use_comparative_bert:
            pair_label_logits = self._classifier(first_vectors_pooled_output)
        else:
            if self._use_bilinear_classifier:
                pair_label_logits = self._classifier(first_vectors_pooled_output, last_vectors_pooled_output)
            else:
                all_pooled_output = torch.cat((first_vectors_pooled_output, last_vectors_pooled_output), 1)
                pair_label_logits = self._classifier(all_pooled_output)

        pair_label_logits = pair_label_logits.view(-1, num_pairs)

        pair_label_probs = torch.sigmoid(pair_label_logits)

        output_dict = {}
        # pair_label_probs already has shape (batch_size, num_pairs)
        output_dict['pair_label_probs'] = pair_label_probs
        output_dict['pair_label_logits'] = pair_label_logits
        output_dict['choice1_indexes'] = choice1_indexes
        output_dict['choice2_indexes'] = choice2_indexes

        if not self._train_comparison_layer:
            if label is not None:
                label = label.unsqueeze(1)
                label = label.expand(-1, num_pairs)
                relevant_pairs = (choice1_indexes == label) | (choice2_indexes == label)
                relevant_probs = pair_label_probs[relevant_pairs]
                choice1_is_the_label = (choice1_indexes == label)[relevant_pairs]

                loss = self._loss(relevant_probs, choice1_is_the_label.float())
                self._accuracy(relevant_probs >= 0.5, choice1_is_the_label)
                output_dict["loss"] = loss

            return output_dict
        else:
            choice_logits = self._comparison_layer_2(
                self._comparison_layer_1_activation(
                    self._comparison_layer_1(pair_label_probs)))
            output_dict['choice_logits'] = choice_logits
            output_dict['choice_probs'] = torch.softmax(choice_logits, 1)
            output_dict['predicted_choice'] = torch.argmax(choice_logits, 1)

            if label is not None:
                loss = self._loss(choice_logits, label)
                self._accuracy(choice_logits, label)
                output_dict["loss"] = loss

        return output_dict
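
In the branch without the comparison layer, only pairs that involve the gold choice contribute to the loss, and the binary target records whether the gold choice sits in the first slot of the pair. A standalone toy sketch of that selection (hypothetical tensors; choice1/choice2 are tensors here rather than the List[int] arguments above, with num_choices=3 so num_pairs=6):

import torch
import torch.nn.functional as F

num_pairs = 6                                        # 3 choices -> 3 * 2 ordered pairs
pair_probs = torch.rand(2, num_pairs)                # (batch, num_pairs)
choice1 = torch.tensor([0, 0, 1, 1, 2, 2]).expand(2, -1)
choice2 = torch.tensor([1, 2, 0, 2, 0, 1]).expand(2, -1)
label = torch.tensor([2, 0]).unsqueeze(1).expand(-1, num_pairs)

relevant = (choice1 == label) | (choice2 == label)   # pairs touching the gold choice
target = (choice1 == label)[relevant].float()        # 1.0 when the gold choice is first
loss = F.binary_cross_entropy(pair_probs[relevant], target)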