Beispiel #1
0
    def get_predictions_and_loss(self,
                                 input_ids,
                                 input_mask,
                                 speaker_ids,
                                 sentence_len,
                                 genre,
                                 sentence_map,
                                 is_training,
                                 gold_starts=None,
                                 gold_ends=None,
                                 gold_mention_cluster_map=None):
        """ Model and input are already on the device """
        device = self.device
        conf = self.config

        do_loss = False
        if gold_mention_cluster_map is not None:
            assert gold_starts is not None
            assert gold_ends is not None
            do_loss = True

        # Get token emb
        mention_doc, _ = self.bert(
            input_ids,
            attention_mask=input_mask)  # [num seg, num max tokens, emb size]
        input_mask = input_mask.to(torch.bool)
        mention_doc = mention_doc[input_mask]
        speaker_ids = speaker_ids[input_mask]
        num_words = mention_doc.shape[0]

        # Get candidate span
        sentence_indices = sentence_map  # [num tokens]
        candidate_starts = torch.unsqueeze(
            torch.arange(0, num_words, device=device),
            1).repeat(1, self.max_span_width)
        candidate_ends = candidate_starts + torch.arange(
            0, self.max_span_width, device=device)
        candidate_start_sent_idx = sentence_indices[candidate_starts]
        candidate_end_sent_idx = sentence_indices[torch.min(
            candidate_ends, torch.tensor(num_words - 1, device=device))]
        candidate_mask = (candidate_ends < num_words) & (
            candidate_start_sent_idx == candidate_end_sent_idx)
        candidate_starts, candidate_ends = candidate_starts[
            candidate_mask], candidate_ends[
                candidate_mask]  # [num valid candidates]
        num_candidates = candidate_starts.shape[0]

        # Get candidate labels
        if do_loss:
            same_start = (torch.unsqueeze(gold_starts, 1) == torch.unsqueeze(
                candidate_starts, 0))
            same_end = (torch.unsqueeze(gold_ends, 1) == torch.unsqueeze(
                candidate_ends, 0))
            same_span = (same_start & same_end).to(torch.long)
            candidate_labels = torch.matmul(
                torch.unsqueeze(gold_mention_cluster_map, 0).to(torch.float),
                same_span.to(torch.float))
            candidate_labels = torch.squeeze(candidate_labels.to(
                torch.long), 0)  # [num candidates]; non-gold span has label 0

        # Get span embedding
        span_start_emb, span_end_emb = mention_doc[
            candidate_starts], mention_doc[candidate_ends]
        candidate_emb_list = [span_start_emb, span_end_emb]
        if conf['use_features']:
            candidate_width_idx = candidate_ends - candidate_starts
            candidate_width_emb = self.emb_span_width(candidate_width_idx)
            candidate_width_emb = self.dropout(candidate_width_emb)
            candidate_emb_list.append(candidate_width_emb)
        # Use attended head or avg token
        candidate_tokens = torch.unsqueeze(
            torch.arange(0, num_words, device=device),
            0).repeat(num_candidates, 1)
        candidate_tokens_mask = (candidate_tokens >= torch.unsqueeze(
            candidate_starts, 1)) & (candidate_tokens <= torch.unsqueeze(
                candidate_ends, 1))
        if conf['model_heads']:
            token_attn = torch.squeeze(self.mention_token_attn(mention_doc), 1)
        else:
            token_attn = torch.ones(num_words,
                                    dtype=torch.float,
                                    device=device)  # Use avg if no attention
        candidate_tokens_attn_raw = torch.log(
            candidate_tokens_mask.to(torch.float)) + torch.unsqueeze(
                token_attn, 0)
        candidate_tokens_attn = nn.functional.softmax(
            candidate_tokens_attn_raw, dim=1)
        head_attn_emb = torch.matmul(candidate_tokens_attn, mention_doc)
        candidate_emb_list.append(head_attn_emb)
        candidate_span_emb = torch.cat(candidate_emb_list,
                                       dim=1)  # [num candidates, new emb size]

        # Get span score
        candidate_mention_scores = torch.squeeze(
            self.span_emb_score_ffnn(candidate_span_emb), 1)
        if conf['use_width_prior']:
            width_score = torch.squeeze(
                self.span_width_score_ffnn(self.emb_span_width_prior.weight),
                1)
            candidate_width_score = width_score[candidate_width_idx]
            candidate_mention_scores += candidate_width_score

        # Extract top spans
        candidate_idx_sorted_by_score = torch.argsort(
            candidate_mention_scores, descending=True).tolist()
        candidate_starts_cpu, candidate_ends_cpu = candidate_starts.tolist(
        ), candidate_ends.tolist()
        num_top_spans = int(
            min(conf['max_num_extracted_spans'],
                conf['top_span_ratio'] * num_words))
        selected_idx_cpu = self._extract_top_spans(
            candidate_idx_sorted_by_score, candidate_starts_cpu,
            candidate_ends_cpu, num_top_spans)
        assert len(selected_idx_cpu) == num_top_spans
        selected_idx = torch.tensor(selected_idx_cpu, device=device)
        top_span_starts, top_span_ends = candidate_starts[
            selected_idx], candidate_ends[selected_idx]
        top_span_emb = candidate_span_emb[selected_idx]
        top_span_cluster_ids = candidate_labels[
            selected_idx] if do_loss else None
        top_span_mention_scores = candidate_mention_scores[selected_idx]

        # Coarse pruning on each mention's antecedents
        max_top_antecedents = min(num_top_spans, conf['max_top_antecedents'])
        top_span_range = torch.arange(0, num_top_spans, device=device)
        antecedent_offsets = torch.unsqueeze(
            top_span_range, 1) - torch.unsqueeze(top_span_range, 0)
        antecedent_mask = (antecedent_offsets >= 1)
        pairwise_mention_score_sum = torch.unsqueeze(
            top_span_mention_scores, 1) + torch.unsqueeze(
                top_span_mention_scores, 0)
        source_span_emb = self.dropout(self.coarse_bilinear(top_span_emb))
        target_span_emb = self.dropout(torch.transpose(top_span_emb, 0, 1))
        pairwise_coref_scores = torch.matmul(source_span_emb, target_span_emb)
        pairwise_fast_scores = pairwise_mention_score_sum + pairwise_coref_scores
        pairwise_fast_scores += torch.log(antecedent_mask.to(torch.float))
        if conf['use_distance_prior']:
            distance_score = torch.squeeze(
                self.antecedent_distance_score_ffnn(
                    self.dropout(self.emb_antecedent_distance_prior.weight)),
                1)
            bucketed_distance = util.bucket_distance(antecedent_offsets)
            antecedent_distance_score = distance_score[bucketed_distance]
            pairwise_fast_scores += antecedent_distance_score
        top_pairwise_fast_scores, top_antecedent_idx = torch.topk(
            pairwise_fast_scores, k=max_top_antecedents)
        top_antecedent_mask = util.batch_select(
            antecedent_mask, top_antecedent_idx,
            device)  # [num top spans, max top antecedents]
        top_antecedent_offsets = util.batch_select(antecedent_offsets,
                                                   top_antecedent_idx, device)

        # Slow mention ranking
        if conf['fine_grained']:
            same_speaker_emb, genre_emb, seg_distance_emb, top_antecedent_distance_emb = None, None, None, None
            if conf['use_metadata']:
                top_span_speaker_ids = speaker_ids[top_span_starts]
                top_antecedent_speaker_id = top_span_speaker_ids[
                    top_antecedent_idx]
                same_speaker = torch.unsqueeze(top_span_speaker_ids,
                                               1) == top_antecedent_speaker_id
                same_speaker_emb = self.emb_same_speaker(
                    same_speaker.to(torch.long))
                genre_emb = self.emb_genre(genre)
                genre_emb = torch.unsqueeze(torch.unsqueeze(genre_emb, 0),
                                            0).repeat(num_top_spans,
                                                      max_top_antecedents, 1)
            if conf['use_segment_distance']:
                num_segs, seg_len = input_ids.shape[0], input_ids.shape[1]
                token_seg_ids = torch.arange(
                    0, num_segs,
                    device=device).unsqueeze(1).repeat(1, seg_len)
                token_seg_ids = token_seg_ids[input_mask]
                top_span_seg_ids = token_seg_ids[top_span_starts]
                top_antecedent_seg_ids = token_seg_ids[
                    top_span_starts[top_antecedent_idx]]
                top_antecedent_seg_distance = torch.unsqueeze(
                    top_span_seg_ids, 1) - top_antecedent_seg_ids
                top_antecedent_seg_distance = torch.clamp(
                    top_antecedent_seg_distance, 0,
                    self.config['max_training_sentences'] - 1)
                seg_distance_emb = self.emb_segment_distance(
                    top_antecedent_seg_distance)
            if conf['use_features']:  # Antecedent distance
                top_antecedent_distance = util.bucket_distance(
                    top_antecedent_offsets)
                top_antecedent_distance_emb = self.emb_top_antecedent_distance(
                    top_antecedent_distance)

            for depth in range(conf['coref_depth']):
                top_antecedent_emb = top_span_emb[
                    top_antecedent_idx]  # [num top spans, max top antecedents, emb size]
                feature_list = []
                if conf['use_metadata']:  # speaker, genre
                    feature_list.append(same_speaker_emb)
                    feature_list.append(genre_emb)
                if conf['use_segment_distance']:
                    feature_list.append(seg_distance_emb)
                if conf['use_features']:  # Antecedent distance
                    feature_list.append(top_antecedent_distance_emb)
                feature_emb = torch.cat(feature_list, dim=2)
                feature_emb = self.dropout(feature_emb)
                target_emb = torch.unsqueeze(top_span_emb,
                                             1).repeat(1, max_top_antecedents,
                                                       1)
                similarity_emb = target_emb * top_antecedent_emb
                pair_emb = torch.cat([
                    target_emb, top_antecedent_emb, similarity_emb, feature_emb
                ], 2)
                top_pairwise_slow_scores = torch.squeeze(
                    self.coref_score_ffnn(pair_emb), 2)
                top_pairwise_scores = top_pairwise_slow_scores + top_pairwise_fast_scores
                if conf['higher_order'] == 'cluster_merging':
                    cluster_merging_scores = ho.cluster_merging(
                        top_span_emb,
                        top_antecedent_idx,
                        top_pairwise_scores,
                        self.emb_cluster_size,
                        self.cluster_score_ffnn,
                        None,
                        self.dropout,
                        device=device,
                        reduce=conf['cluster_reduce'],
                        easy_cluster_first=conf['easy_cluster_first'])
                    break
                elif depth != conf['coref_depth'] - 1:
                    if conf['higher_order'] == 'attended_antecedent':
                        refined_span_emb = ho.attended_antecedent(
                            top_span_emb, top_antecedent_emb,
                            top_pairwise_scores, device)
                    elif conf['higher_order'] == 'max_antecedent':
                        refined_span_emb = ho.max_antecedent(
                            top_span_emb, top_antecedent_emb,
                            top_pairwise_scores, device)
                    elif conf['higher_order'] == 'entity_equalization':
                        refined_span_emb = ho.entity_equalization(
                            top_span_emb, top_antecedent_emb,
                            top_antecedent_idx, top_pairwise_scores, device)
                    elif conf['higher_order'] == 'span_clustering':
                        refined_span_emb = ho.span_clustering(
                            top_span_emb, top_antecedent_idx,
                            top_pairwise_scores, self.span_attn_ffnn, device)

                    gate = self.gate_ffnn(
                        torch.cat([top_span_emb, refined_span_emb], dim=1))
                    gate = torch.sigmoid(gate)
                    top_span_emb = gate * refined_span_emb + (
                        1 -
                        gate) * top_span_emb  # [num top spans, span emb size]
        else:
            top_pairwise_scores = top_pairwise_fast_scores  # [num top spans, max top antecedents]

        if not do_loss:
            if conf['fine_grained'] and conf[
                    'higher_order'] == 'cluster_merging':
                top_pairwise_scores += cluster_merging_scores
            top_antecedent_scores = torch.cat(
                [
                    torch.zeros(num_top_spans, 1, device=device),
                    top_pairwise_scores
                ],
                dim=1)  # [num top spans, max top antecedents + 1]
            return candidate_starts, candidate_ends, candidate_mention_scores, top_span_starts, top_span_ends, top_antecedent_idx, top_antecedent_scores

        # Get gold labels
        top_antecedent_cluster_ids = top_span_cluster_ids[top_antecedent_idx]
        top_antecedent_cluster_ids += (top_antecedent_mask.to(
            torch.long) - 1) * 100000  # Mask id on invalid antecedents
        same_gold_cluster_indicator = (
            top_antecedent_cluster_ids == torch.unsqueeze(
                top_span_cluster_ids, 1))
        non_dummy_indicator = torch.unsqueeze(top_span_cluster_ids > 0, 1)
        pairwise_labels = same_gold_cluster_indicator & non_dummy_indicator
        dummy_antecedent_labels = torch.logical_not(
            pairwise_labels.any(dim=1, keepdims=True))
        top_antecedent_gold_labels = torch.cat(
            [dummy_antecedent_labels, pairwise_labels], dim=1)

        # Get loss
        top_antecedent_scores = torch.cat([
            torch.zeros(num_top_spans, 1, device=device), top_pairwise_scores
        ],
                                          dim=1)
        if conf['loss_type'] == 'marginalized':
            log_marginalized_antecedent_scores = torch.logsumexp(
                top_antecedent_scores +
                torch.log(top_antecedent_gold_labels.to(torch.float)),
                dim=1)
            log_norm = torch.logsumexp(top_antecedent_scores, dim=1)
            loss = torch.sum(log_norm - log_marginalized_antecedent_scores)
        elif conf['loss_type'] == 'hinge':
            top_antecedent_mask = torch.cat([
                torch.ones(num_top_spans, 1, dtype=torch.bool, device=device),
                top_antecedent_mask
            ],
                                            dim=1)
            top_antecedent_scores += torch.log(
                top_antecedent_mask.to(torch.float))
            highest_antecedent_scores, highest_antecedent_idx = torch.max(
                top_antecedent_scores, dim=1)
            gold_antecedent_scores = top_antecedent_scores + torch.log(
                top_antecedent_gold_labels.to(torch.float))
            highest_gold_antecedent_scores, highest_gold_antecedent_idx = torch.max(
                gold_antecedent_scores, dim=1)
            slack_hinge = 1 + highest_antecedent_scores - highest_gold_antecedent_scores
            # Calculate delta
            highest_antecedent_is_gold = (
                highest_antecedent_idx == highest_gold_antecedent_idx)
            mistake_false_new = (highest_antecedent_idx
                                 == 0) & torch.logical_not(
                                     dummy_antecedent_labels.squeeze())
            delta = ((3 - conf['false_new_delta']) / 2) * torch.ones(
                num_top_spans, dtype=torch.float, device=device)
            delta -= (1 - conf['false_new_delta']) * mistake_false_new.to(
                torch.float)
            delta *= torch.logical_not(highest_antecedent_is_gold).to(
                torch.float)
            loss = torch.sum(slack_hinge * delta)

        # Add mention loss
        if conf['mention_loss_coef']:
            gold_mention_scores = top_span_mention_scores[
                top_span_cluster_ids > 0]
            non_gold_mention_scores = top_span_mention_scores[
                top_span_cluster_ids == 0]
            loss_mention = -torch.sum(
                torch.log(torch.sigmoid(
                    gold_mention_scores))) * conf['mention_loss_coef']
            loss_mention += -torch.sum(
                torch.log(1 - torch.sigmoid(non_gold_mention_scores))
            ) * conf['mention_loss_coef']
            loss += loss_mention

        if conf['higher_order'] == 'cluster_merging':
            top_pairwise_scores += cluster_merging_scores
            top_antecedent_scores = torch.cat([
                torch.zeros(num_top_spans, 1, device=device),
                top_pairwise_scores
            ],
                                              dim=1)
            log_marginalized_antecedent_scores2 = torch.logsumexp(
                top_antecedent_scores +
                torch.log(top_antecedent_gold_labels.to(torch.float)),
                dim=1)
            log_norm2 = torch.logsumexp(top_antecedent_scores,
                                        dim=1)  # [num top spans]
            loss_cm = torch.sum(log_norm2 -
                                log_marginalized_antecedent_scores2)
            if conf['cluster_dloss']:
                loss += loss_cm
            else:
                loss = loss_cm

        # Debug
        if self.debug:
            if self.update_steps % 20 == 0:
                logger.info('---------debug step: %d---------' %
                            self.update_steps)
                # logger.info('candidates: %d; antecedents: %d' % (num_candidates, max_top_antecedents))
                logger.info('spans/gold: %d/%d; ratio: %.2f' %
                            (num_top_spans, (top_span_cluster_ids > 0).sum(),
                             (top_span_cluster_ids > 0).sum() / num_top_spans))
                if conf['mention_loss_coef']:
                    logger.info('mention loss: %.4f' % loss_mention)
                if conf['loss_type'] == 'marginalized':
                    logger.info(
                        'norm/gold: %.4f/%.4f' %
                        (torch.sum(log_norm),
                         torch.sum(log_marginalized_antecedent_scores)))
                else:
                    logger.info('loss: %.4f' % loss)
        self.update_steps += 1

        return [
            candidate_starts, candidate_ends, candidate_mention_scores,
            top_span_starts, top_span_ends, top_antecedent_idx,
            top_antecedent_scores
        ], loss
Beispiel #2
0
def cluster_merging(top_span_emb, top_antecedent_idx, top_antecedent_scores, emb_cluster_size, cluster_score_ffnn, cluster_transform, dropout, device, reduce='mean', easy_cluster_first=False):
    num_top_spans, max_top_antecedents = top_antecedent_idx.shape[0], top_antecedent_idx.shape[1]
    span_emb_size = top_span_emb.shape[-1]
    max_num_clusters = num_top_spans

    span_to_cluster_id = torch.zeros(num_top_spans, dtype=torch.long, device=device)  # id 0 as dummy cluster
    cluster_emb = torch.zeros(max_num_clusters, span_emb_size, dtype=torch.float, device=device)  # [max num clusters, emb size]
    num_clusters = 1  # dummy cluster
    cluster_sizes = torch.ones(max_num_clusters, dtype=torch.long, device=device)

    merge_order = torch.arange(0, num_top_spans)
    if easy_cluster_first:
        max_antecedent_scores, _ = torch.max(top_antecedent_scores, dim=1)
        merge_order = torch.argsort(max_antecedent_scores, descending=True)
    cluster_merging_scores = [None] * num_top_spans

    for i in merge_order.tolist():
        # Get cluster scores
        antecedent_cluster_idx = span_to_cluster_id[top_antecedent_idx[i]]
        antecedent_cluster_emb = cluster_emb[antecedent_cluster_idx]
        # antecedent_cluster_emb = dropout(cluster_transform(antecedent_cluster_emb))

        antecedent_cluster_size = cluster_sizes[antecedent_cluster_idx]
        antecedent_cluster_size = util.bucket_distance(antecedent_cluster_size)
        cluster_size_emb = dropout(emb_cluster_size(antecedent_cluster_size))

        span_emb = top_span_emb[i].unsqueeze(0).repeat(max_top_antecedents, 1)
        similarity_emb = span_emb * antecedent_cluster_emb
        pair_emb = torch.cat([span_emb, antecedent_cluster_emb, similarity_emb, cluster_size_emb], dim=1)  # [max top antecedents, pair emb size]
        cluster_scores = torch.squeeze(cluster_score_ffnn(pair_emb), 1)
        cluster_scores_mask = (antecedent_cluster_idx > 0).to(torch.float)
        cluster_scores *= cluster_scores_mask
        cluster_merging_scores[i] = cluster_scores

        # Get predicted antecedent
        antecedent_scores = top_antecedent_scores[i] + cluster_scores
        max_score, max_score_idx = torch.max(antecedent_scores, dim=0)
        if max_score < 0:
            continue  # Dummy antecedent
        max_antecedent_idx = top_antecedent_idx[i, max_score_idx]

        if not easy_cluster_first:  # Always add span to antecedent's cluster
            # Create antecedent cluster if needed
            antecedent_cluster_id = span_to_cluster_id[max_antecedent_idx]
            if antecedent_cluster_id == 0:
                antecedent_cluster_id = num_clusters
                span_to_cluster_id[max_antecedent_idx] = antecedent_cluster_id
                cluster_emb[antecedent_cluster_id] = top_span_emb[max_antecedent_idx]
                num_clusters += 1
            # Add span to cluster
            span_to_cluster_id[i] = antecedent_cluster_id
            _merge_span_to_cluster(cluster_emb, cluster_sizes, antecedent_cluster_id, top_span_emb[i], reduce=reduce)
        else:  # current span can be in cluster already
            antecedent_cluster_id = span_to_cluster_id[max_antecedent_idx]
            curr_span_cluster_id = span_to_cluster_id[i]
            if antecedent_cluster_id > 0 and curr_span_cluster_id > 0:
                # Merge two clusters
                span_to_cluster_id[max_antecedent_idx] = curr_span_cluster_id
                _merge_clusters(cluster_emb, cluster_sizes, antecedent_cluster_id, curr_span_cluster_id, reduce=reduce)
            elif curr_span_cluster_id > 0:
                # Merge antecedent to span's cluster
                span_to_cluster_id[max_antecedent_idx] = curr_span_cluster_id
                _merge_span_to_cluster(cluster_emb, cluster_sizes, curr_span_cluster_id, top_span_emb[max_antecedent_idx], reduce=reduce)
            else:
                # Create antecedent cluster if needed
                if antecedent_cluster_id == 0:
                    antecedent_cluster_id = num_clusters
                    span_to_cluster_id[max_antecedent_idx] = antecedent_cluster_id
                    cluster_emb[antecedent_cluster_id] = top_span_emb[max_antecedent_idx]
                    num_clusters += 1
                # Add span to cluster
                span_to_cluster_id[i] = antecedent_cluster_id
                _merge_span_to_cluster(cluster_emb, cluster_sizes, antecedent_cluster_id, top_span_emb[i], reduce=reduce)

    cluster_merging_scores = torch.stack(cluster_merging_scores, dim=0)
    return cluster_merging_scores