예제 #1
0
 def extract_vectors_from_markers(self, embbeds, location):
     """Extract and renormalise the embedding vectors at the marker positions.

     ``embbeds`` is flattened to three dimensions (its last two axes are
     preserved) and ``location`` is turned into a long index tensor on
     ``self.device``.  Both are handed to ``self.extractor``; its output goes
     through ``self.renorm_vector`` and is returned on ``self.device``.
     """
     flat_embeds = embbeds.view(-1, embbeds.size(-2), embbeds.size(-1))
     # The positions are materialised as a float tensor first and only then
     # cast to long, exactly as the indices are expected by the extractor.
     marker_positions = util.combine_initial_dims(torch.FloatTensor(location))
     marker_positions = marker_positions.long().to(self.device)
     extracted = self.extractor(flat_embeds, marker_positions)
     return self.renorm_vector(extracted).to(self.device)
예제 #2
0
    def forward(self, inputs):
        """Embed ``inputs`` and return the result with the input's leading dims.

        The ids are flattened with ``util.combine_initial_dims`` before the
        ``embedding`` lookup against ``self.weight`` (using
        ``self.padding_index`` as ``padding_idx``), then reshaped back with
        ``util.uncombine_initial_dims``.
        """
        input_shape = inputs.size()
        flat_inputs = util.combine_initial_dims(inputs)
        flat_embedded = embedding(flat_inputs,
                                  self.weight,
                                  padding_idx=self.padding_index)
        return util.uncombine_initial_dims(flat_embedded, input_shape)
def get_select_embedding(sub_words_embedding, offsets):
    """Pick, for every row, the sub-word embeddings located at ``offsets``.

    ``offsets`` is (batch_size, d1, ..., dn, orig_sequence_length).  It is
    flattened to two dimensions, used to index ``sub_words_embedding`` row by
    row, and the gathered embeddings are reshaped back to the leading
    dimensions of ``offsets``.
    """
    target_shape = offsets.size()
    # (batch_size * d1 * ... * dn, orig_sequence_length)
    flat_offsets = util.combine_initial_dims(offsets)
    # Column vector of row numbers so that it broadcasts against the offsets.
    row_index = util.get_range_vector(
        flat_offsets.size(0),
        device=util.get_device_of(sub_words_embedding))
    row_index = row_index.unsqueeze(1)
    # One embedding per (row, offset) pair.
    gathered = sub_words_embedding[row_index, flat_offsets]
    return util.uncombine_initial_dims(gathered, target_shape)
    def forward(self, inputs):  # pylint: disable=arguments-differ
        """Run the glyph embedding on ``inputs`` of arbitrary leading shape.

        ``inputs`` may be (batch_size, d1, ..., dn, sequence_length); it is
        flattened with ``util.combine_initial_dims`` before being passed to
        ``self.glyph_embedding`` and the embedding is reshaped back afterwards.

        Returns a pair ``(embedded, glyph_classification_loss)`` where the
        second element is returned by ``self.glyph_embedding`` unchanged.
        """
        input_shape = inputs.size()
        flat_inputs = util.combine_initial_dims(inputs)

        glyph_output = self.glyph_embedding(flat_inputs)
        emb, glyph_classification_loss = glyph_output

        # Restore any extra leading dimensions that were flattened above.
        restored = util.uncombine_initial_dims(emb, input_shape)
        return restored, glyph_classification_loss
예제 #5
0
    def forward(self, inputs):
        """Embed ``inputs`` with every lookup option configured on the module.

        The ids are flattened with ``util.combine_initial_dims``, looked up in
        ``self.weight`` through ``embedding`` and reshaped back to the leading
        dimensions of ``inputs``.
        """
        input_shape = inputs.size()
        flat_inputs = util.combine_initial_dims(inputs)

        # Options forwarded verbatim to the embedding lookup.
        lookup_options = {
            'padding_idx': self.padding_index,
            'max_norm': self.max_norm,
            'norm_type': self.norm_type,
            'scale_grad_by_freq': self.scale_grad_by_freq,
            'sparse': self.sparse,
        }
        flat_embedded = embedding(flat_inputs, self.weight, **lookup_options)

        return util.uncombine_initial_dims(flat_embedded, input_shape)
예제 #6
0
    def forward(self, inputs):  # pylint: disable=arguments-differ
        """Embed ``inputs``, optionally cast to half precision, then apply dropout.

        The ids are flattened with ``util.combine_initial_dims`` for the
        ``embedding`` lookup and the result is reshaped back.  When
        ``self._use_fp16`` is truthy the embeddings are converted with
        ``.half()`` before ``self._dropout`` is applied.
        """
        input_shape = inputs.size()
        flat_inputs = util.combine_initial_dims(inputs)

        lookup_options = {
            'max_norm': self.max_norm,
            'norm_type': self.norm_type,
            'scale_grad_by_freq': self.scale_grad_by_freq,
            'sparse': self.sparse,
        }
        flat_embedded = embedding(flat_inputs, self.weight, **lookup_options)

        # Restore any extra leading dimensions that were flattened above.
        restored = util.uncombine_initial_dims(flat_embedded, input_shape)
        if self._use_fp16:
            restored = restored.half()
        return self._dropout(restored)
예제 #7
0
 def get_token_type_ids(tokens, sep_token):
     """
     Returns the token type ids, to be used in BERT's segment embeddings.

     ``tokens`` must be a 2- or 3-dimensional tensor of token ids and
     ``sep_token`` the id of the separator token.  The result has the same
     shape as ``tokens``.

     If every sequence contains exactly one separator, all type ids are 0
     (default BERT behaviour).  Otherwise every position strictly after the
     first separator of its sequence gets type id 1 and everything else 0.
     """
     assert (tokens.dim() in [
         2, 3
     ]), 'get_token_type_ids only supports {2,3}-dimensional sequences.'
     orig_size = tokens.size()
     if tokens.dim() == 3:
         tokens = util.combine_initial_dims(tokens)
     sep_token_mask = (tokens == sep_token).long()
     # Check the separator count per sequence.  Comparing the total count with
     # the batch size is not enough: one row with two separators and one row
     # with none would be mistaken for "one [SEP] per sample".
     one_sep_per_sample = bool((sep_token_mask.sum(-1) == 1).all())
     if one_sep_per_sample:
         # Use default BERT (all 0's) if there's 1 [SEP] per sample
         return torch.zeros_like(tokens).view(orig_size)
     # cumsum counts separators seen so far (inclusive); subtracting the mask
     # makes the count exclusive, so the first separator itself stays at 0.
     return (sep_token_mask.cumsum(-1) -
             sep_token_mask).clamp(max=1).view(orig_size)
예제 #8
0
    def forward(self, inputs):  # pylint: disable=arguments-differ
        """Embed ``inputs`` and, if configured, project the embeddings.

        ``inputs`` may be (batch_size, d1, ..., dn, sequence_length).  It is
        flattened with ``util.combine_initial_dims`` (the lookup expects a
        2-d tensor), embedded against ``self.weight`` and reshaped back.
        When ``self._projection`` is truthy it is applied to the result,
        wrapped in one ``TimeDistributed`` per dimension beyond the second.
        """
        input_shape = inputs.size()
        flat_inputs = util.combine_initial_dims(inputs)

        lookup_options = {
            'max_norm': self.max_norm,
            'norm_type': self.norm_type,
            'scale_grad_by_freq': self.scale_grad_by_freq,
            'sparse': self.sparse,
        }
        flat_embedded = embedding(flat_inputs, self.weight, **lookup_options)

        # Restore any extra leading dimensions that were flattened above.
        restored = util.uncombine_initial_dims(flat_embedded, input_shape)

        if not self._projection:
            return restored

        projection = self._projection
        for _ in range(restored.dim() - 2):
            projection = TimeDistributed(projection)
        return projection(restored)
예제 #9
0
    def forward(self, inputs, input_type='entity'):  # pylint: disable=arguments-differ
        """Embed sequences of entity or predicate keys, growing the vocabulary on the fly.

        Parameters
        ----------
        inputs : list of sequences of keys
            Each inner sequence is mapped to ids through the id table of the
            selected type.  Unknown keys are assigned ``len(weight)`` as their
            id and ``self.add_new_embedding`` is called to extend the weight.
            The caller's list is not modified.
        input_type : str
            ``'entity'`` or ``'predicate'``; selects the weight and id table.

        Returns
        -------
        The result of ``self.project`` applied to the looked-up embeddings.

        Raises
        ------
        ValueError
            If ``input_type`` is neither ``'entity'`` nor ``'predicate'``.
        """
        # Select the weight matrix and id table for the requested type.
        if input_type == 'entity':
            weight = self.entity_weight
            element2id = self.entity2id
        elif input_type == 'predicate':
            weight = self.predicate_weight
            element2id = self.predicate2id
        else:
            raise ValueError(
                "{} is not a valid input type, use 'entity' or 'predicate'.".
                format(input_type))

        # Find ids and add new ones if non-existent.  The padded id rows are
        # collected in a fresh list so the caller's ``inputs`` keeps its keys
        # (writing ids back would make a second call treat ids as new keys).
        max_len = max(len(keys) for keys in inputs)
        padded_ids = []
        for keys in inputs:
            ids = []
            for key in keys:
                if key not in element2id:
                    element2id[key] = len(weight)
                    weight = self.add_new_embedding(weight, input_type)
                ids.append(element2id[key])
            # Right-pad with id 0 up to the longest sequence.
            padded_ids.append(ids + [0] * (max_len - len(keys)))
        id_tensor = torch.LongTensor(padded_ids)

        # Find embeddings of ids.
        original_size = id_tensor.size()
        id_tensor = util.combine_initial_dims(id_tensor)
        id_tensor = util.move_to_device(id_tensor, self.cuda_device)
        embedded = embedding(id_tensor,
                             weight,
                             max_norm=self.max_norm,
                             norm_type=self.norm_type,
                             scale_grad_by_freq=self.scale_grad_by_freq,
                             sparse=self.sparse)
        embedded = util.uncombine_initial_dims(embedded, original_size)
        return self.project(embedded, input_type)
예제 #10
0
    def test_combine_initial_dims(self):
        """All leading dimensions collapse into one; the last one is kept."""
        five_d = torch.randn(4, 10, 20, 17, 5)

        flattened = util.combine_initial_dims(five_d)
        expected_shape = [4 * 10 * 20 * 17, 5]
        assert list(flattened.size()) == expected_shape
예제 #11
0
    def forward(self,
                phrase: Dict[str, torch.LongTensor],
                masked_labels: torch.LongTensor = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
        """
        Score answer choices by summing masked-LM predictions at their masked positions.

        Parameters
        ----------
        phrase : ``Dict[str, torch.LongTensor]``
            ``phrase['tokens']`` holds the input ids passed to the transformer.
        masked_labels : ``torch.LongTensor``, optional
            Masked-LM labels.  Entries equal to 0 are replaced by -1 in a copy
            before being handed to the model (presumably the model's ignore
            index -- TODO confirm).  The caller's tensor is left untouched.
        label : ``torch.LongTensor``, optional
            Passed to ``self._accuracy`` and, when ``self._loss_on_all_vocab``
            is falsy, to the transformer model.
        metadata : ``List[Dict[str, Any]]``
            One dict per example.  ``choice_text_list`` and
            ``all_masked_index_ids`` are always read; ``id``,
            ``question_text`` and ``correct_answer_index`` are read when
            predictions are written to ``self._predictions_file``.

        Returns
        -------
        A dict with the single key ``"loss"``.

        Raises
        ------
        NotImplementedError
            For XLNet models: that code path never produced the loss and
            predictions this method needs.
        ValueError
            If ``self._pretrained_model`` is not a supported model name.
        """
        self._debug -= 1
        input_ids = phrase['tokens']
        if masked_labels is not None:
            # Build a new tensor rather than writing into the caller's labels.
            masked_labels = masked_labels.masked_fill(masked_labels == 0, -1)

        batch_size = input_ids.size(0)
        num_choices = len(metadata[0]['choice_text_list'])

        # Segment ids are not used by RoBERTa
        if 'roberta' in self._pretrained_model or 'bert' in self._pretrained_model:
            if self._loss_on_all_vocab:
                outputs = self._transformer_model(
                    input_ids=util.combine_initial_dims(input_ids),
                    masked_lm_labels=masked_labels)
            else:
                outputs = self._transformer_model(
                    input_ids=util.combine_initial_dims(input_ids),
                    masked_lm_labels=masked_labels,
                    all_masked_index_ids=[
                        e['all_masked_index_ids'] for e in metadata
                    ],
                    label=label)
            loss, predictions = outputs[:2]
        elif 'xlnet' in self._pretrained_model:
            # The previous XLNet branch referenced an undefined ``segment_ids``
            # and never set ``loss``/``predictions``, so it could only crash.
            raise NotImplementedError(
                "Masked-LM choice scoring is not implemented for XLNet models "
                "({!r}).".format(self._pretrained_model))
        else:
            raise ValueError(
                "Unsupported pretrained model: {!r}".format(
                    self._pretrained_model))

        output_dict = {}
        label_logits = torch.zeros(batch_size, num_choices)
        for e, example in enumerate(metadata):
            # A choice's logit is the sum of the predictions at each of its
            # (position, vocabulary id) pairs.
            for c, choice in enumerate(example['all_masked_index_ids']):
                for t in choice:
                    label_logits[e, c] += predictions[e, t[0], t[1]]

            # TODO this is shortcut to get predictions fast..
            if self._predictions_file is not None and not self.training:
                with open(self._predictions_file, 'a') as f:
                    logits = label_logits[e, :].cpu().data.numpy().astype(
                        float)
                    pred_ind = np.argmax(logits)
                    record = {
                        'question_id': example['id'],
                        'phrase': example['question_text'],
                        'choices': example['choice_text_list'],
                        'logits': list(logits),
                        'answer_ind': example['correct_answer_index'],
                        'prediction': example['choice_text_list'][pred_ind],
                        'is_correct': (example['correct_answer_index'] == pred_ind) * 1.0,
                    }
                    f.write(json.dumps(record) + '\n')

        self._accuracy(label_logits, label)
        output_dict["loss"] = loss

        if self._debug > 0:
            print(output_dict)
        return output_dict
예제 #12
0
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None,
                history_encoding: torch.LongTensor = None,
                turn_encoding: torch.LongTensor = None,
                scenario_encoding: torch.LongTensor = None) -> torch.Tensor:
        """
        Embed wordpiece ids with ``self.bert_model``, splitting over-long inputs
        into windows of ``self.max_pieces`` and optionally selecting one
        embedding per original token.

        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.

            Note that this implementation prepends position 0 to every row of
            offsets before selecting, so one extra embedding is selected per row.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
        history_encoding : ``torch.LongTensor``, optional
            Extra per-wordpiece ids forwarded to ``self.bert_model``; defaults to
            all 0s with the shape of ``input_ids``.
        turn_encoding : ``torch.LongTensor``, optional
            Extra per-wordpiece ids forwarded to ``self.bert_model``; defaults to
            all 0s with the shape of ``input_ids``.
        scenario_encoding : ``torch.LongTensor``, optional
            Extra per-wordpiece ids forwarded to ``self.bert_model``; defaults to
            all 0s with the shape of ``input_ids``.

        Returns
        -------
        The mixed (``self._scalar_mix``) or last-layer embeddings, reshaped with
        ``util.uncombine_initial_dims``; restricted to the ``offsets`` positions
        when ``offsets`` is given.
        """
        # pylint: disable=arguments-differ
        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # This can then be fed into BERT without any sentence length issues. Keep in mind
        # that the memory consumption can dramatically increase for large batches with extremely
        # long sentences.
        needs_split = full_seq_len > self.max_pieces
        # NOTE(review): last_window_size is assigned here but never read in this method.
        last_window_size = 0
        if needs_split:
            # Every per-wordpiece input that was provided is split the same way as input_ids.
            input_ids = self.split_indices(input_ids)
            if token_type_ids is not None:
                token_type_ids = self.split_indices(token_type_ids)
            if history_encoding is not None:
                history_encoding = self.split_indices(history_encoding)
            if turn_encoding is not None:
                turn_encoding = self.split_indices(turn_encoding)
            if scenario_encoding is not None:
                scenario_encoding = self.split_indices(scenario_encoding)

        # Missing optional inputs default to all-zero tensors shaped like the
        # (possibly already split) input_ids.
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        if history_encoding is None:
            history_encoding = torch.zeros_like(input_ids)
        if turn_encoding is None:
            turn_encoding = torch.zeros_like(input_ids)
        if scenario_encoding is None:
            scenario_encoding = torch.zeros_like(input_ids)

        # Wordpiece id 0 is treated as padding.
        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers, pooled_output = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(token_type_ids),
            history_encoding=util.combine_initial_dims(history_encoding),
            turn_encoding=util.combine_initial_dims(turn_encoding),
            scenario_encoding=util.combine_initial_dims(scenario_encoding),
            attention_mask=util.combine_initial_dims(input_mask))
        # Stack the per-layer outputs along a new leading "layers" dimension.
        all_encoder_layers = torch.stack(all_encoder_layers)

        if needs_split:
            # First, unpack the output embeddings into one long sequence again
            unpacked_embeddings = torch.split(all_encoder_layers,
                                              batch_size,
                                              dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)
            # The split path only supports a single example that has at least one
            # non-zero token type id.
            # NOTE(review): this runtime check is an assert and is stripped under -O.
            assert batch_size == 1 and token_type_ids.max() > 0
            # Number of positions in the first row whose token type id is non-zero.
            num_question_tokens = token_type_ids[0].nonzero().size(0)
            select_indices = self.indices_to_select(full_seq_len,
                                                    num_question_tokens)
            # The recombined sequence length becomes the last "initial" dim used
            # when reshaping the output below.
            initial_dims.append(len(select_indices))
            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # Recombine the outputs of all layers
        # (layers, batch_size * d1 * ... * dn, sequence_length, embedding_dim)
        # recombined = torch.cat(combined, dim=2)
        # NOTE(review): this mask is element-wise over the embeddings (including the
        # layer and embedding dims), not over token positions -- confirm that this is
        # what self._scalar_mix expects.
        input_mask = (recombined_embeddings != 0).long()

        if self._scalar_mix is not None:
            mix = self._scalar_mix(recombined_embeddings, input_mask)
        else:
            # Without a scalar mix, use the last stacked layer only.
            mix = recombined_embeddings[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            dims = initial_dims if needs_split else input_ids.size()
            return util.uncombine_initial_dims(mix, dims)
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            # Prepend a column of zeros so position 0 is always selected as well.
            zeros = torch.zeros(offsets2d.size(0),
                                1,
                                dtype=offsets2d.dtype,
                                device=offsets2d.device)
            offsets2d = torch.cat([zeros, offsets2d], dim=-1)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length + 1)
            range_vector = util.get_range_vector(
                offsets2d.size(0), device=util.get_device_of(mix)).unsqueeze(1)
            # selected embeddings is also (batch_size * d1 * ... * dn, orig_sequence_length + 1)
            selected_embeddings = mix[range_vector, offsets2d]

            # NOTE(review): the selection has orig_sequence_length + 1 positions while
            # offsets.size() still has orig_sequence_length -- verify that
            # uncombine_initial_dims handles this when offsets has extra dimensions.
            return util.uncombine_initial_dims(selected_embeddings,
                                               offsets.size())
예제 #13
0
    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Embed ``tokens`` as a softmax-weighted combination of neighbour embeddings.

        When the ``warm_mode`` or ``weighted_off`` flag is set, this is a plain
        embedding lookup of ``tokens`` against ``self.weight``.  Otherwise
        ``self.hull.get_nbr_and_coeff`` supplies, for every token, a set of
        neighbour ids and coefficients; the neighbours' embeddings are combined
        with softmax weights derived from those coefficients (or, in adversarial
        mode, from the logits stored by the previous pass).

        Returns a tensor shaped ``(*tokens.size(), self.weight.size(1))``.
        """
        # One-time notice: print on the first call only, guarded by a flag.
        if not ram_has_flag("EXE_ONCE.weighted_embedding"):
            print("The weighted embedding is working")
            import sys
            sys.stdout.flush()
            ram_set_flag("EXE_ONCE.weighted_embedding")

        # Weighting disabled: ordinary lookup, flattened to 2-d and reshaped back.
        # NOTE(review): the meaning of the second argument (True) to ram_has_flag is
        # not visible here -- confirm against its definition.
        if ram_has_flag("warm_mode", True) or ram_has_flag("weighted_off", True):
            embedded = embedding(
                util.combine_initial_dims(tokens),
                self.weight,
                padding_idx=self.padding_index,
                max_norm=self.max_norm,
                norm_type=self.norm_type,
                scale_grad_by_freq=self.scale_grad_by_freq,
                sparse=self.sparse,
            )
            embedded = util.uncombine_initial_dims(embedded, tokens.size())
            return embedded
        # Neighbour ids and coefficients for every token of the flattened input.
        nbr_tokens, _coeff = self.hull.get_nbr_and_coeff(tokens.view(-1))

        # n_words x n_nbrs x dim
        embedded = embedding(
            nbr_tokens,
            self.weight,
            padding_idx=self.padding_index,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )

        if not adv_utils.is_adv_mode():
            # Normal mode: logits are the log of the coefficients; the epsilon
            # keeps log() finite for zero coefficients.
            coeff_logit = (_coeff + 1e-6).log()
        else:
            # Adversarial mode: start from the values stored under the
            # "coeff_logit" hook and move along the normalised second value.
            # presumably last_fw / last_bw are the previous forward value and its
            # gradient -- TODO confirm against adv_utils.
            last_fw, last_bw = adv_utils.read_var_hook("coeff_logit")
            # coeff_logit = last_fw + adv_utils.recieve("step") * last_bw
            norm_last_bw = last_bw / (torch.norm(last_bw, dim=-1, keepdim=True) + 1e-6)
            coeff_logit = last_fw + adv_utils.recieve("step") * norm_last_bw

        # Shift every row so that its maximum is 0; softmax is unchanged by this shift.
        coeff_logit = coeff_logit - coeff_logit.max(1, keepdim=True)[0]

        # Track gradients on the logits and register them under the same hook
        # name that the adversarial branch reads from.
        coeff_logit.requires_grad_()
        adv_utils.register_var_hook("coeff_logit", coeff_logit)
        coeff = F.softmax(coeff_logit, dim=1)

        # if adv_utils.is_adv_mode():
        #     last_coeff = F.softmax(last_fw, dim=1)
        #     new_points = (embedded[:20] * coeff[:20].unsqueeze(-1)).sum(-2)
        #     old_points = (embedded[:20] * last_coeff[:20].unsqueeze(-1)).sum(-2)
        #     step_size = (new_points - old_points).norm(dim=-1).mean()
        #     inner_size = (embedded[:20, 1:] - embedded[:20, :1]).norm(dim=-1).mean()
        #     print(round(inner_size.item(), 3), round(step_size.item(), 3))
        # Weighted sum over the neighbour axis, then restore the token shape.
        embedded = (embedded * coeff.unsqueeze(-1)).sum(-2)
        embedded = embedded.view(*tokens.size(), self.weight.size(1))
        if adv_utils.is_adv_mode():
            if ram_has_flag("adjust_point"):
                raw_embedded = embedding(
                    tokens,
                    self.weight,
                    padding_idx=self.padding_index,
                    max_norm=self.max_norm,
                    norm_type=self.norm_type,
                    scale_grad_by_freq=self.scale_grad_by_freq,
                    sparse=self.sparse,
                )
                # Both operands are detached, so delta carries no gradient: the
                # result has the value of the weighted embedding, but gradients
                # flow only through the raw lookup.
                delta = embedded.detach() - raw_embedded.detach()
                embedded = raw_embedded + delta
        return embedded
예제 #14
0
    def forward(self,
                phrase: Dict[str, torch.LongTensor],
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
        """
        Classify a single sequence with the configured transformer.

        Parameters
        ----------
        phrase : ``Dict[str, torch.LongTensor]``
            ``phrase['tokens']['token_ids']`` and ``phrase['tokens']['type_ids']``
            hold the input ids and segment ids.
        label : ``torch.LongTensor``, optional
            Gold labels.  When given, the loss and accuracy are computed and one
            prediction record per example is appended to ``self._predictions``.
        metadata : ``List[Dict[str, Any]]``, optional
            Per-example dicts; ``id``, ``question_text``, ``context`` and
            ``correct_answer_index`` are read for the prediction records
            (``skills`` and ``tags`` are copied when present).

        Returns
        -------
        A dict with ``label_logits``, ``label_probs``, ``answer_index`` and,
        when ``label`` is given, ``loss``.

        Raises
        ------
        ValueError
            If ``self._pretrained_model`` is not a supported model name.
        """
        self._debug -= 1
        input_ids = phrase['tokens']['token_ids']
        segment_ids = phrase['tokens']['type_ids']

        question_mask = (input_ids != self._padding_value).long()

        # This must be a single if/elif chain, checked in this order: 'bert' is
        # a substring of both 'roberta' and 'albert', so those names have to be
        # matched before the plain BERT branch or the model would be run a
        # second time with segment ids.
        # Segment ids are not used by RoBERTa
        if ('roberta' in self._pretrained_model
                or 't5' in self._pretrained_model
                or 'albert' in self._pretrained_model):
            transformer_outputs, pooled_output = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self._dropout(pooled_output)
        elif 'xlnet' in self._pretrained_model:
            transformer_outputs = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                token_type_ids=util.combine_initial_dims(segment_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self.sequence_summary(transformer_outputs[0])
        elif 'bert' in self._pretrained_model:
            last_layer, pooled_output = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                token_type_ids=util.combine_initial_dims(segment_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self._dropout(pooled_output)
        else:
            raise ValueError(
                "Unsupported pretrained model: {!r}".format(
                    self._pretrained_model))

        label_logits = self._classifier(cls_output)

        output_dict = {}
        output_dict['label_logits'] = label_logits
        output_dict['label_probs'] = torch.nn.functional.softmax(label_logits, dim=1)
        output_dict['answer_index'] = label_logits.argmax(1)

        if label is not None:
            loss = self._loss(label_logits, label)
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

            # TODO this is shortcut to get predictions fast..
            # ``metadata`` is optional, so tolerate it being absent.
            for e, example in enumerate(metadata or []):
                logits = sanitize(label_logits[e, :])
                label_probs = sanitize(output_dict['label_probs'][e, :])
                prediction = sanitize(output_dict['answer_index'][e])
                prediction_dict = {
                    'id': example['id'],
                    'phrase': example['question_text'],
                    'context': example['context'],
                    'logits': logits,
                    'label_probs': label_probs,
                    'answer': example['correct_answer_index'],
                    'prediction': prediction,
                    'is_correct': (example['correct_answer_index'] == prediction) * 1.0,
                }

                if 'skills' in example:
                    prediction_dict['skills'] = example['skills']
                if 'tags' in example:
                    prediction_dict['tags'] = example['tags']
                self._predictions.append(prediction_dict)

        return output_dict
예제 #15
0
    def test_combine_initial_dims(self):
        """A 5-d tensor is flattened to 2-d with its last dimension intact."""
        source = torch.randn(4, 10, 20, 17, 5)

        combined = util.combine_initial_dims(source)
        actual_shape = list(combined.size())
        assert actual_shape == [4 * 10 * 20 * 17, 5]
예제 #16
0
    def forward(self,
                question: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
        """
        Score every answer choice of a multiple-choice question.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            ``question['tokens']`` holds the input ids; its second dimension is
            taken as the number of choices.
        segment_ids : ``torch.LongTensor``, optional
            Segment ids, passed to the model in the XLNet and BERT branches only.
        label : ``torch.LongTensor``, optional
            Gold labels.  When given, the loss and accuracy are computed.
        metadata : ``List[Dict[str, Any]]``, optional
            Not read by this method.

        Returns
        -------
        A dict with ``label_logits`` (reshaped to ``(-1, num_choices)``),
        ``label_probs``, ``answer_index`` and, when ``label`` is given, ``loss``.

        Raises
        ------
        ValueError
            If ``self._pretrained_model`` is not a supported model name.
        """
        self._debug -= 1
        input_ids = question['tokens']

        num_choices = input_ids.size(1)

        question_mask = (input_ids != self._padding_value).long()

        # This must be a single if/elif chain, checked in this order: 'bert' is
        # a substring of both 'roberta' and 'albert', so those names have to be
        # matched before the plain BERT branch or the model would be run a
        # second time with segment ids.
        # Segment ids are not used by RoBERTa
        if ('roberta' in self._pretrained_model
                or 'albert' in self._pretrained_model):
            transformer_outputs, pooled_output = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self._dropout(pooled_output)
        elif 'xlnet' in self._pretrained_model:
            transformer_outputs = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                token_type_ids=util.combine_initial_dims(segment_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self.sequence_summary(transformer_outputs[0])
        elif 'bert' in self._pretrained_model:
            last_layer, pooled_output = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                token_type_ids=util.combine_initial_dims(segment_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self._dropout(pooled_output)
        else:
            raise ValueError(
                "Unsupported pretrained model: {!r}".format(
                    self._pretrained_model))

        # One logit per (example, choice); regroup them so that every row holds
        # the choices of one example.
        label_logits = self._classifier(cls_output)
        label_logits = label_logits.view(-1, num_choices)

        output_dict = {}
        output_dict['label_logits'] = label_logits

        output_dict['label_probs'] = torch.nn.functional.softmax(label_logits,
                                                                 dim=1)
        output_dict['answer_index'] = label_logits.argmax(1)

        if label is not None:
            loss = self._loss(label_logits, label)
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if self._debug > 0:
            print(output_dict)
        return output_dict
예제 #17
0
    def compute_logits_and_value(
        self,  # type: ignore
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        options: Dict[str, torch.LongTensor],
        sep_token: int,
        options_to_support: torch.FloatTensor = None,
        all_past_sent_choice_mask: torch.LongTensor = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        """
        Architecture-specific forward pass
        """
        # BERT-formatting input
        batch_size, num_options, _ = options['tokens'].size()
        pqo_tokens_list = []
        pqo_token_maxlens = []
        for i in range(num_options):
            qo_tokens = self.pack_sequences(question['tokens'],
                                            options['tokens'][:, i])
            pqo_tokens_list.append(
                self.pack_sequences(passage['tokens'], qo_tokens, sep_token))
            pqo_token_maxlens.append(pqo_tokens_list[i].size(-1))
        pqo_tokens = torch.zeros(batch_size,
                                 num_options,
                                 max(pqo_token_maxlens),
                                 dtype=torch.long,
                                 device=passage['tokens'].device)
        for i in range(num_options):
            pqo_tokens[:, i, :pqo_tokens_list[i].size(-1)] = pqo_tokens_list[i]
        pqo = self.tokens_to_bert_input(pqo_tokens, sep_token)

        # Condition debater on stance. Also add in past debater choices
        if not self.is_judge:
            pqo['token-type-ids'][:, :,
                                  0] = options_to_support  # Change segment embedding for [CLS] tokens only.
            if all_past_sent_choice_mask is not None:
                pqo_sent_chosen_mask = torch.zeros(
                    batch_size,
                    max(pqo_token_maxlens),
                    dtype=torch.long,
                    device=passage['tokens'].device)
                pqo_sent_chosen_mask[:, :all_past_sent_choice_mask.
                                     size(1)] = all_past_sent_choice_mask
                pqo_sent_chosen_mask = pqo_sent_chosen_mask.unsqueeze(
                    1).expand(-1, num_options, -1)
                pqo['token-type-ids'] = (pqo['token-type-ids'] +
                                         pqo_sent_chosen_mask).clamp(max=1)
                # other_embeddings = self._sent_chosen_embeddings(pqo_sent_chosen_mask).unsqueeze(1).expand(-1, num_options, -1, -1)
                # pqo['other-embeddings'] = other_embeddings.view(-1, other_embeddings.size(-2), other_embeddings.size(-1))

        hidden_pqo = self._text_field_embedder(pqo)
        if self.is_judge:
            pred_hidden_a = hidden_pqo[:, :, 0]
            option_logits = self._logit_predictor(pred_hidden_a).squeeze(-1)
            # Expose detached hidden states for theory-of-mind debaters
            self.pqo = pqo
            self.hidden_pqo = hidden_pqo.detach()
            self.pred_hidden_a = pred_hidden_a.detach()
            self.option_logits = option_logits.detach()
            return option_logits, None, None
        else:
            # Predict answer (auxiliary SL loss)
            option_logits = None
            if self._qa_loss_weight > 0:
                pred_hidden_a = hidden_pqo[:, :, 0]
                option_logits = self._logit_predictor(pred_hidden_a).squeeze(
                    -1)

            # Condition on option to support (again)
            agent_film_params = self._turn_film_gen(
                options_to_support.unsqueeze(-1))
            agent_gammas, agent_betas = torch.split(agent_film_params,
                                                    self._hidden_dim,
                                                    dim=-1)
            agent_hidden_pqo = self._film(
                hidden_pqo, 1. + agent_gammas,
                agent_betas) * pqo['mask'].float().unsqueeze(-1)

            # Process Judge hidden states
            if self.theory_of_mind:
                # Condition Judge states on Debater opinion (to highlight strong candidate sentences)
                cond_judge_hidden_pqo = self._film(
                    self.judge.hidden_pqo, 1. + agent_gammas,
                    agent_betas) * self.judge.pqo['mask'].float().unsqueeze(-1)
                # Align Judge states to Debater's full passage states
                shifted_judge_hidden_pqo = torch.zeros_like(agent_hidden_pqo)
                seq_lengths = util.get_lengths_from_binary_sequence_mask(
                    pqo['mask'])
                judge_seq_lengths = util.get_lengths_from_binary_sequence_mask(
                    self.judge.pqo['mask'])
                for i in range(batch_size):
                    for j in range(num_options):
                        shifted_judge_hidden_pqo[i, j, seq_lengths[i, j] - judge_seq_lengths[i, j]: seq_lengths[i, j]] = \
                            cond_judge_hidden_pqo[i, j, :judge_seq_lengths[i, j]]
                agent_hidden_pqo = self.final_blocks_input_proj(
                    torch.cat([agent_hidden_pqo, shifted_judge_hidden_pqo],
                              dim=-1))
                # Joint processing with transformer block
                extended_attention_mask = util.combine_initial_dims(
                    pqo['mask']).unsqueeze(1).unsqueeze(2)
                extended_attention_mask = extended_attention_mask.to(
                    dtype=next(self.final_blocks.parameters()).dtype)
                extended_attention_mask = (1.0 -
                                           extended_attention_mask) * -10000.0
                agent_hidden_pqo = self.final_blocks(
                    agent_hidden_pqo.view(batch_size * num_options, -1,
                                          self._hidden_dim),
                    extended_attention_mask,
                    output_all_encoded_layers=False)[-1]
                # Reshape and remask
                agent_hidden_pqo = agent_hidden_pqo.view(
                    batch_size, num_options, -1,
                    self._hidden_dim) * pqo['mask'].float().unsqueeze(-1)

            # Predict distribution over sentence actions
            tokenwise_values = self._value_head(agent_hidden_pqo.detach(
            ) if self._detach_value_head else agent_hidden_pqo).squeeze(-1)
            value, value_option = util.replace_masked_values(
                tokenwise_values, pqo['mask'], -1e7).max(-1)[0].max(-1)
            policy_logits = self._policy_head(agent_hidden_pqo).squeeze(
                -1).sum(1)
            return option_logits, policy_logits, value
예제 #18
0
    def forward(self,
                question: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                label: torch.LongTensor = None,
                binary_labels: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
        """
        Score every answer choice of a question and optionally compute the loss.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            Must provide ``question['tokens']['token_ids']`` and
            ``question['tokens']['mask']``. The first two dimensions of the ids are
            read as batch size and number of choices.
        segment_ids : ``torch.LongTensor``, optional
            Only logged in debug mode; it is not passed to the transformer here.
        label : ``torch.LongTensor``, optional
            Gold label; passed to ``self._loss`` (non-binary mode) and ``self._accuracy``.
        binary_labels : ``torch.LongTensor``, optional
            Per-choice labels used when ``self._binary_loss`` is set.
        metadata : ``List[Dict[str, Any]]``, optional
            Unused by this method.

        Returns
        -------
        A dict with ``label_logits``, ``label_probs``, ``answer_index`` and, when
        labels were supplied, ``loss``.
        """
        # Each call consumes one unit of the debug budget; logging stops once it is spent.
        self._debug -= 1
        debugging = self._debug > 0

        token_ids_field = question['tokens']['token_ids']
        input_ids = token_ids_field
        batch_size, num_choices = input_ids.size(0), input_ids.size(1)
        num_binary_choices = 1

        # The mask comes from the indexer output rather than being derived from a padding id.
        question_mask = question['tokens']['mask']

        if debugging:
            logger.info(f"batch_size = {batch_size}")
            logger.info(f"num_choices = {num_choices}")
            logger.info(f"question_mask = {question_mask}")
            logger.info(f"input_ids.size() = {input_ids.size()}")
            logger.info(f"input_ids = {input_ids}")
            logger.info(f"segment_ids = {segment_ids}")
            logger.info(f"label = {label}")
            logger.info(f"binary_labels = {binary_labels}")

        # Flatten the leading (batch, choice) dimensions before calling the transformer.
        flat_input_ids = util.combine_initial_dims(input_ids)
        flat_mask = util.combine_initial_dims(question_mask)
        transformer_outputs = self._transformer_model(input_ids=flat_input_ids,
                                                      attention_mask=flat_mask)

        # Only the first element of the model output is fed to the classifier.
        cls_output = transformer_outputs[0]
        if debugging:
            logger.info(f"cls_output = {cls_output}")

        raw_logits = self._classifier(cls_output)
        # Two views of the same scores: one column per example (binary loss) and
        # one row per question (multi-class loss).
        label_logits_binary = raw_logits.view(-1, num_binary_choices)
        label_logits = raw_logits.view(-1, num_choices)

        binary_mode = self._binary_loss
        if binary_mode:
            label_probs = self._sigmoid(label_logits)
        else:
            label_probs = torch.nn.functional.softmax(label_logits, dim=1)

        output_dict = {
            'label_logits': label_logits,
            'label_probs': label_probs,
            'answer_index': label_logits.argmax(1),
        }

        use_binary_targets = binary_mode and binary_labels is not None
        if use_binary_targets or label is not None:
            if use_binary_targets:
                binary_targets = binary_labels.reshape(
                    -1, num_binary_choices).to(label_logits.dtype)
                loss = self._loss(label_logits_binary, binary_targets)
            else:
                loss = self._loss(label_logits, label)
            # Accuracy is always tracked against the per-question label.
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if debugging:
            logger.info(output_dict)
        return output_dict
예제 #19
0
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Run the wrapped BERT model over wordpiece ids and return its encoder layers.

        Inputs whose last dimension exceeds ``self.max_pieces`` are cut into windows of
        that length, stacked along the batch dimension for the model call, and then
        stitched back into a single sequence.

        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.
            When the input is windowed, these ids are windowed the same way.

        Returns
        -------
        ``torch.Tensor``
            ``self._scalar_mix(layers, mask)`` when a scalar mix is configured, the last
            layer when ``self.combine_layers == "last"``, otherwise all stacked layers.
        """
        # pylint: disable=arguments-differ
        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive an input tensor that has a sequence length longer than can
        # be fit. In that case, we should expect the wordpiece indexer to create padded windows
        # of length `self.max_pieces` for us, and have them concatenated into one long sequence.
        # E.g., "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ..."
        # We can then split the sequence into sub-sequences of that length, and concatenate them
        # along the batch dimension so we effectively have one huge batch of partial sentences.
        # Keep in mind that the memory consumption can dramatically increase for large batches
        # with extremely long sentences.
        needs_split = full_seq_len > self.max_pieces
        if needs_split:
            # Split by the window size and zero-pad the last window to full length
            # so that all windows can be concatenated along the batch dimension.
            split_input_ids = list(input_ids.split(self.max_pieces, dim=-1))
            padding_amount = self.max_pieces - split_input_ids[-1].size(-1)
            split_input_ids[-1] = F.pad(split_input_ids[-1],
                                        pad=[0, padding_amount],
                                        value=0)
            input_ids = torch.cat(split_input_ids, dim=0)

        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        elif needs_split:
            # Caller-supplied type ids must be windowed exactly like input_ids,
            # otherwise their shape no longer matches the re-batched input.
            split_token_type_ids = list(
                token_type_ids.split(self.max_pieces, dim=-1))
            type_padding_amount = self.max_pieces - split_token_type_ids[-1].size(-1)
            split_token_type_ids[-1] = F.pad(split_token_type_ids[-1],
                                             pad=[0, type_padding_amount],
                                             value=0)
            token_type_ids = torch.cat(split_token_type_ids, dim=0)

        # Id 0 is treated as padding.
        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers, _ = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(token_type_ids),
            attention_mask=util.combine_initial_dims(input_mask))
        all_encoder_layers = torch.stack(all_encoder_layers)

        if needs_split:
            # First, unpack the output embeddings into one long sequence again:
            # undo the batch-wise stacking and lay the windows side by side.
            unpacked_embeddings = torch.split(all_encoder_layers,
                                              batch_size,
                                              dim=1)
            unpacked_embeddings = torch.cat(unpacked_embeddings, dim=2)

            # Next, select indices so that the result represents the original sentence.
            # To capture maximal context we keep the middle part of every window, plus
            # the leading edge of the first window and the trailing edge of the last.
            # E.g. with max_pieces = 8 and one start and one end token, a 16-piece input
            # gives stride = 3 and stride_offset = 2, hence
            #   first window   -> [0, 1]
            #   context pieces -> [2, 3, 4, 10, 11, 12]
            #   final window   -> [13, 14, 15]
            stride = (self.max_pieces - self.start_tokens -
                      self.end_tokens) // 2
            stride_offset = stride // 2 + self.start_tokens

            first_window = list(range(stride_offset))

            max_context_windows = [
                i for i in range(full_seq_len) if stride_offset - 1 < i %
                self.max_pieces < stride_offset + stride
            ]

            # Look back over what is left in the last window. When the sequence length is
            # an exact multiple of max_pieces the remainder is 0, but the last window is a
            # full one; using the raw remainder would place the start past the end of the
            # sequence and silently drop the trailing pieces.
            remainder = full_seq_len % self.max_pieces
            lookback = remainder if remainder != 0 else self.max_pieces

            final_window_start = full_seq_len - lookback + stride_offset + stride
            final_window = list(range(final_window_start, full_seq_len))

            select_indices = first_window + max_context_windows + final_window

            initial_dims.append(len(select_indices))

            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # recombined_embeddings keeps the stacked layers as its first dimension.
        # The mask handed to the scalar mix is derived from non-zero activations.
        input_mask = (recombined_embeddings != 0).long()

        if offsets is None:
            # Restore any extra leading dimensions of the input.
            dims = initial_dims if needs_split else input_ids.size()
            layers = util.uncombine_initial_dims(recombined_embeddings, dims)
        else:
            # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
            offsets2d = util.combine_initial_dims(offsets)
            # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
            range_vector = util.get_range_vector(
                offsets2d.size(0),
                device=util.get_device_of(recombined_embeddings)).unsqueeze(1)
            # For every layer, pick the wordpiece vector at each token's offset.
            selected_embeddings = recombined_embeddings[:, range_vector,
                                                        offsets2d]

            layers = util.uncombine_initial_dims(selected_embeddings,
                                                 offsets.size())

        if self._scalar_mix is not None:
            return self._scalar_mix(layers, input_mask)
        elif self.combine_layers == "last":
            return layers[-1]
        else:
            return layers
예제 #20
0
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None) -> torch.Tensor:
        """
        Embed wordpiece ids with the wrapped BERT model, optionally keeping one
        vector per original token.

        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.
            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.

        Returns
        -------
        ``torch.Tensor``
            The scalar mix of the encoder output when ``self._scalar_mix`` is set,
            otherwise its last entry, either for every wordpiece or (with ``offsets``)
            for the selected wordpieces only.
        """
        # Original author's note: offsets correspond to the indexer output; GECToR uses
        # the "start" convention, i.e. each offset is the index of a token's first
        # wordpiece within the sentence's wordpiece list.
        batch_size, full_seq_len = input_ids.size(0), input_ids.size(-1)
        # All dimensions except the sequence one (the batch size comes first).
        initial_dims = list(input_ids.shape[:-1])

        # The embedder may receive a sequence longer than the model can fit. In that case
        # the wordpiece indexer is expected to have produced padded windows of length
        # `self.max_pieces`, concatenated into one long sequence, e.g.
        # "[CLS] I went to the [SEP] [CLS] to the store to [SEP] ...".
        # The windows are moved to the batch dimension so the model sees one big batch of
        # partial sentences. Memory use can grow a lot for large batches of long sentences.
        window = self.max_pieces
        needs_split = full_seq_len > window
        if needs_split:
            # Cut the input into windows of at most `max_pieces` pieces.
            chunks = list(input_ids.split(window, dim=-1))
            # Zero-pad the last window so that every window has the same length.
            tail = chunks[-1]
            chunks[-1] = F.pad(tail,
                               pad=[0, window - tail.size(-1)],
                               value=0)
            # Stack the windows along the batch dimension.
            input_ids = torch.cat(chunks, dim=0)

        # Padding mask for attention: keeps attention away from the zero padding.
        input_mask = (input_ids != 0).long()

        # input_ids may have extra dimensions, so both tensors are reshaped to 2-d
        # before the model call and the result is reshaped back at the end.
        # Original author's note: the first value returned by the model is
        # `last_hidden_state`, a FloatTensor of shape
        # (batch_size, sequence_length, hidden_size) - TODO confirm against the model used.
        flat_ids = util.combine_initial_dims(input_ids)
        flat_mask = util.combine_initial_dims(input_mask)
        all_encoder_layers = self.bert_model(
            input_ids=flat_ids,
            attention_mask=flat_mask,
        )[0]

        # Make sure the result is 4-dimensional (layers first).
        rank_of_first = all_encoder_layers[0].dim()
        if rank_of_first == 3:
            all_encoder_layers = torch.stack(all_encoder_layers)
        elif rank_of_first == 2:
            all_encoder_layers = torch.unsqueeze(all_encoder_layers, dim=0)

        if needs_split:
            # The input was longer than max_pieces: undo the windowing. Splitting along the
            # batch dimension yields one chunk per window, each holding the original batch;
            # concatenating them along the sequence dimension restores one long sequence.
            per_window = torch.split(all_encoder_layers, batch_size, dim=1)
            unpacked_embeddings = torch.cat(per_window, dim=2)

            # Select the indices that best represent the original sentence: the middle of
            # every window (maximal context), plus the leading edge of the first window
            # and the trailing edge of the last one.
            # E.g. with max_pieces = 8 and one start and one end token, a 16-piece input
            # gives stride = 3 and a context range of [2, 5) inside each window, hence
            #   first window   -> [0, 1]
            #   context pieces -> [2, 3, 4, 10, 11, 12]
            #   final window   -> [13, 14, 15]
            stride = (window - self.num_start_tokens -
                      self.num_end_tokens) // 2
            context_start = stride // 2 + self.num_start_tokens
            context_end = context_start + stride

            # Leading part of the first window.
            select_indices = list(range(context_start))
            # The central `stride` wordpieces of every window.
            select_indices.extend(
                i for i in range(full_seq_len)
                if context_start <= i % window < context_end)

            # Look back over what is left in the last window, unless it is a whole
            # max_pieces window (remainder 0), in which case look back a full window.
            lookback = full_seq_len % window or window
            # Trailing part of the last window.
            select_indices.extend(
                range(full_seq_len - lookback + context_end, full_seq_len))

            # Record the new sequence length as the last dimension.
            initial_dims.append(len(select_indices))
            # Keep only the selected positions of each sentence.
            recombined_embeddings = unpacked_embeddings[:, :, select_indices]
        else:
            recombined_embeddings = all_encoder_layers

        # Mask derived from non-zero activations, handed to the scalar mix.
        input_mask = (recombined_embeddings != 0).long()

        if self._scalar_mix is None:
            mix = recombined_embeddings[-1]
        else:
            mix = self._scalar_mix(recombined_embeddings, input_mask)

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is not None:
            # Original author's note: offsets are built in gec_model's preprocess step with
            # shape [batch_size, seq_len].
            offsets2d = util.combine_initial_dims(offsets)
            # get_range_vector returns e.g. [0, 1, 2, 3, 4] for a size of 5; the extra
            # dimension lets it broadcast against offsets2d when indexing.
            row_index = util.get_range_vector(
                offsets2d.size(0), device=util.get_device_of(mix)).unsqueeze(1)
            # Replace every token by the wordpiece vector found at its offset.
            selected_embeddings = mix[row_index, offsets2d]
            # Reshape back to the dimensions of offsets plus the embedding dimension.
            return util.uncombine_initial_dims(selected_embeddings,
                                               offsets.size())

        # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
        dims = initial_dims if needs_split else input_ids.size()
        return util.uncombine_initial_dims(mix, dims)
예제 #21
0
    def forward(self,
                question: Dict[str, torch.LongTensor],
                choice1_indexes: List[int] = None,
                choice2_indexes: List[int] = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
        """
        Score pairs of answer choices with BERT and optionally derive a per-choice
        prediction from the pair scores.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            Must provide ``question['bert']`` of size
            (batch_size, num_pairs, max_sentence_length) and ``question['bert-type-ids']``.
        choice1_indexes, choice2_indexes : optional
            Indexes of the two choices of each pair; copied to the output and compared
            with ``label`` when the comparison layer is not trained.
        label : ``torch.LongTensor``, optional
            Gold label used for the loss and accuracy.
        metadata : ``List[Dict[str, Any]]``, optional
            Unused by this method.

        Returns
        -------
        A dict with the pair probabilities/logits and choice indexes, plus
        ``choice_logits``, ``choice_probs`` and ``predicted_choice`` when
        ``self._train_comparison_layer`` is set, and ``loss`` when ``label`` is given.
        """
        self._debug -= 1
        input_ids = question['bert']

        # input_ids.size() == (batch_size, num_pairs, max_sentence_length)
        batch_size, num_pairs, _ = question['bert'].size()
        # Id 0 is treated as padding.
        question_mask = (input_ids != 0).long()

        if self._train_comparison_layer:
            # The comparison layer expects one entry per ordered pair of choices.
            assert num_pairs == self._num_choices * (self._num_choices - 1)

        # Segment ids: work on a copy so that the caller's tensor is not modified.
        segment_ids = question['bert-type-ids'].clone()
        # Move the last 'SEP' of segment 1 into segment 2 (for symmetry): it is the
        # position with id 1 whose successor has id 2.
        boundary_seps = (segment_ids.roll(-1) == 2) & (segment_ids == 1)
        segment_ids[boundary_seps] = 2
        # Collapse to binary ids: 1 where the id is 0 or 2, and 0 elsewhere.
        real_segment_ids = ((segment_ids == 0) | (segment_ids == 2)).long()

        # TODO: How to extract last token pooled output if batch size != 1
        assert batch_size == 1

        # Run model on inputs flattened over the (batch, pair) dimensions.
        flat_ids = util.combine_initial_dims(input_ids)
        flat_segments = util.combine_initial_dims(real_segment_ids)
        flat_mask = util.combine_initial_dims(question_mask)
        encoded_layers, first_vectors_pooled_output = self._bert_model(
            input_ids=flat_ids,
            token_type_ids=flat_segments,
            attention_mask=flat_mask,
            output_all_encoded_layers=self._all_layers)

        comparative = self._use_comparative_bert
        last_vectors_pooled_output = None
        if comparative:
            last_vectors_pooled_output = self._extract_last_token_pooled_output(
                encoded_layers, question_mask)
        if self._all_layers:
            # Pool the scalar mix of all layers instead of using the model's own pooled output.
            mixed_layer = self._scalar_mix(encoded_layers, question_mask)
            first_vectors_pooled_output = self._bert_model.pooler(mixed_layer)

        # Apply dropout (first-token vector first, then the last-token vector).
        first_vectors_pooled_output = self._dropout(first_vectors_pooled_output)
        if comparative:
            last_vectors_pooled_output = self._dropout(last_vectors_pooled_output)

        # Classify
        if not comparative:
            scores = self._classifier(first_vectors_pooled_output)
        elif self._use_bilinear_classifier:
            scores = self._classifier(first_vectors_pooled_output,
                                      last_vectors_pooled_output)
        else:
            joined = torch.cat(
                (first_vectors_pooled_output, last_vectors_pooled_output), 1)
            scores = self._classifier(joined)

        pair_label_logits = scores.view(-1, num_pairs)
        pair_label_probs = torch.sigmoid(pair_label_logits)

        output_dict = {
            'pair_label_probs': pair_label_probs.squeeze(1).view(-1, num_pairs),
            'pair_label_logits': pair_label_logits,
            'choice1_indexes': choice1_indexes,
            'choice2_indexes': choice2_indexes,
        }

        if self._train_comparison_layer:
            # Two-layer head mapping pair probabilities to per-choice logits.
            hidden = self._comparison_layer_1_activation(
                self._comparison_layer_1(pair_label_probs))
            choice_logits = self._comparison_layer_2(hidden)
            output_dict['choice_logits'] = choice_logits
            output_dict['choice_probs'] = torch.softmax(choice_logits, 1)
            output_dict['predicted_choice'] = torch.argmax(choice_logits, 1)

            if label is not None:
                loss = self._loss(choice_logits, label)
                self._accuracy(choice_logits, label)
                output_dict["loss"] = loss
            return output_dict

        if label is not None:
            # Broadcast the label over all pairs and keep only pairs containing it.
            label = label.unsqueeze(1).expand(-1, num_pairs)
            choice1_matches = choice1_indexes == label
            relevant_pairs = choice1_matches | (choice2_indexes == label)
            relevant_probs = pair_label_probs[relevant_pairs]
            # Target is 1 when the first choice of the pair is the gold answer.
            choice1_is_the_label = choice1_matches[relevant_pairs]

            loss = self._loss(relevant_probs, choice1_is_the_label.float())
            self._accuracy(relevant_probs >= 0.5, choice1_is_the_label)
            output_dict["loss"] = loss

        return output_dict
예제 #22
0
    def forward(self,
                question: Dict[str, torch.LongTensor],
                segment_ids: torch.LongTensor = None,
                label: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None) -> torch.Tensor:
        """
        Score every answer choice of a question with the configured transformer.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            Must provide ``question['tokens']['token_ids']``; its second dimension is
            read as the number of choices.
        segment_ids : ``torch.LongTensor``, optional
            Token type ids; passed to the model in the BERT and XLNet branches only.
        label : ``torch.LongTensor``, optional
            Gold label used for the loss and accuracy.
        metadata : ``List[Dict[str, Any]]``, optional
            Unused by this method.

        Returns
        -------
        A dict with ``label_logits``, ``label_probs``, ``answer_index`` and, when
        ``label`` is given, ``loss``.

        Raises
        ------
        ValueError
            If ``self._pretrained_model`` names none of the supported model families.
        """
        self._debug -= 1
        input_ids = question['tokens']['token_ids']
        num_choices = input_ids.size(1)

        question_mask = (input_ids != self._padding_value).long()

        # The branches form a single if/elif chain and 'roberta' is tested first: the
        # substring 'bert' also occurs in 'roberta' (and 'albert'), so the generic BERT
        # branch must only be reached when no more specific name matched.
        if 'roberta' in self._pretrained_model:
            # Segment ids are not used by RoBERTa
            output = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self._dropout(output.pooler_output)
        elif 'albert' in self._pretrained_model:
            # Branch deliberately disabled by the original author.
            assert False
            transformer_outputs, pooled_output = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self._dropout(pooled_output)
        elif 'xlnet' in self._pretrained_model:
            # Branch deliberately disabled by the original author.
            assert False
            transformer_outputs = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                token_type_ids=util.combine_initial_dims(segment_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self.sequence_summary(transformer_outputs[0])
        elif 'bert' in self._pretrained_model:
            output = self._transformer_model(
                input_ids=util.combine_initial_dims(input_ids),
                token_type_ids=util.combine_initial_dims(segment_ids),
                attention_mask=util.combine_initial_dims(question_mask))
            cls_output = self._dropout(output.pooler_output)
        else:
            # Fail loudly instead of falling through with `cls_output` unbound.
            raise ValueError(
                f"Unsupported pretrained model: {self._pretrained_model!r}")

        label_logits = self._classifier(cls_output)
        # One row per question, one column per choice.
        label_logits = label_logits.view(-1, num_choices)

        output_dict = {}
        output_dict['label_logits'] = label_logits

        output_dict['label_probs'] = torch.nn.functional.softmax(label_logits,
                                                                 dim=1)
        output_dict['answer_index'] = label_logits.argmax(1)

        if label is not None:
            loss = self._loss(label_logits, label)
            self._accuracy(label_logits, label)
            output_dict["loss"] = loss

        if self._debug > 0:
            print(output_dict)
        return output_dict
예제 #23
0
    def forward(self,  # type: ignore
                question: Dict[str, torch.LongTensor],
                passage: Dict[str, torch.LongTensor],
                span_starts: torch.IntTensor = None,
                span_ends: torch.IntTensor = None,
                yesno_labels : torch.IntTensor = None,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        Run BERT over the passage word pieces, score every position as a span
        start / span end, and decode one answer string per instance.

        Parameters
        ----------
        question : ``Dict[str, torch.LongTensor]``
            Not read by this method; kept for interface compatibility.
        passage : ``Dict[str, torch.LongTensor]``
            Must contain ``'bert'`` (a 2-d tensor of word-piece ids, where id 0 is
            masked out) and ``'bert-offsets'`` (used to map token indices to
            word-piece indices).
        span_starts, span_ends : ``torch.IntTensor``, optional
            Gold span boundaries expressed as token indices. When either is
            ``None`` no loss and no EM/F1 scores are computed.
        yesno_labels : ``torch.IntTensor``, optional
            Gold indices into the ``"yesno_labels"`` vocabulary namespace.
        metadata : ``List[Dict[str, Any]]``
            Per-instance dictionaries. The keys ``'original_passage'``,
            ``'token_offsets'`` and ``'question_id'`` are always read;
            ``'cannot_answer'`` and ``'answer_texts_list'`` are read when gold
            spans are supplied. Required in practice (it is iterated).

        Returns
        -------
        A dictionary with the per-instance lists ``'best_span_str'``,
        ``'best_span_logit'``, ``'cannot_answer_logit'``, ``'yesno'``,
        ``'yesno_logit'`` and ``'qid'``. When gold spans are supplied it also
        contains ``'loss'`` and the per-instance lists ``'EM'`` and ``'f1'``.
        """
        input_ids = passage['bert']
        # Unpacking (rather than ``size(0)``) keeps the requirement that the ids are 2-d.
        batch_size, _ = input_ids.size()

        # Executing the BERT model on the word piece ids (input_ids)
        token_type_ids = torch.zeros_like(input_ids)
        mask = (input_ids != 0).long()
        embedded_chunk, _ = \
            self._text_field_embedder.token_embedder_bert.bert_model(
                input_ids=util.combine_initial_dims(input_ids),
                token_type_ids=util.combine_initial_dims(token_type_ids),
                attention_mask=util.combine_initial_dims(mask),
                output_all_encoded_layers=False)

        # Lengths and offsets needed to convert between tokens and word pieces.
        passage_length = embedded_chunk.size(1)
        bert_offsets = passage['bert-offsets'].cpu().numpy()

        # BERT for QA is a fully connected linear layer on top of BERT producing 2 vectors of
        # start and end spans.
        logits = self.qa_outputs(embedded_chunk)
        start_logits, end_logits = logits.split(1, dim=-1)
        span_start_logits = start_logits.squeeze(-1)
        span_end_logits = end_logits.squeeze(-1)
        device = span_end_logits.device

        # All input is preprocessed before forward is run; counting the yesno vocabulary
        # indicates whether yesno support is needed at all.
        use_yesno = self.vocab.get_vocab_size("yesno_labels") > 1
        if use_yesno:
            yesno_logits = self.qa_yesno(torch.max(embedded_chunk, 1)[0])

        # Decide once, before the gold tensors are converted, whether gold answers exist.
        has_gold_spans = span_starts is not None and span_ends is not None

        output_dict: Dict[str, Any] = {}
        if has_gold_spans:
            # Out-of-place clamp so the caller's tensors are not modified.
            # NOTE(review): the upper bound is inclusive, so an index equal to
            # passage_length can survive the clamp -- confirm this is intended.
            span_starts = span_starts.clamp(0, passage_length)
            span_ends = span_ends.clamp(0, passage_length)

            # Moving to word piece indexes from token indexes of start and end span.
            span_starts_list = [bert_offsets[i, span_starts[i]] if span_starts[i] != 0 else 0
                                for i in range(batch_size)]
            span_ends_list = [bert_offsets[i, span_ends[i]] if span_ends[i] != 0 else 0
                              for i in range(batch_size)]
            # Build the targets on the same device as the logits, whatever that device is.
            gold_starts = torch.tensor(span_starts_list, dtype=torch.long, device=device)
            gold_ends = torch.tensor(span_ends_list, dtype=torch.long, device=device)

            loss_fct = CrossEntropyLoss(ignore_index=passage_length)
            start_loss = loss_fct(span_start_logits, gold_starts)
            end_loss = loss_fct(span_end_logits, gold_ends)

            if use_yesno and yesno_labels is not None:
                yesno_loss = loss_fct(yesno_logits, yesno_labels)
                loss = (start_loss + end_loss + yesno_loss) / 3
            else:
                loss = (start_loss + end_loss) / 2

            if loss == 0:
                # For evaluation purposes only!
                output_dict["loss"] = torch.zeros(1, dtype=torch.float, device=device)
            else:
                output_dict["loss"] = loss

        # Compute F1 and prepare the output dictionary.
        output_dict['best_span_str'] = []
        output_dict['best_span_logit'] = []
        output_dict['cannot_answer_logit'] = []
        output_dict['yesno'] = []
        output_dict['yesno_logit'] = []
        output_dict['qid'] = []
        if has_gold_spans:
            output_dict['EM'] = []
            output_dict['f1'] = []

        # Getting the best span prediction for every instance.
        best_span = self._get_example_predications(span_start_logits, span_end_logits, self._max_span_length)
        best_span_cpu = best_span.detach().cpu().numpy()

        # Convert once instead of once per instance inside the loop.
        start_logits_np = span_start_logits.detach().cpu().numpy()
        end_logits_np = span_end_logits.detach().cpu().numpy()
        if has_gold_spans and yesno_labels is not None:
            yesno_labels_np = yesno_labels.detach().cpu().numpy()
        else:
            yesno_labels_np = None

        for instance_ind, instance_metadata in zip(range(batch_size), metadata):
            predicted_span = best_span_cpu[instance_ind]
            best_span_logit = start_logits_np[instance_ind, predicted_span[0]] + \
                              end_logits_np[instance_ind, predicted_span[1]]
            # Position 0 stands for the "cannot answer" option.
            cannot_answer_logit = start_logits_np[instance_ind, 0] + \
                                  end_logits_np[instance_ind, 0]

            if use_yesno:
                yesno_maxind = np.argmax(yesno_logits[instance_ind].detach().cpu().numpy())
                yesno_logit = yesno_logits[instance_ind, yesno_maxind].detach().cpu().numpy()
                yesno_pred = self.vocab.get_token_from_index(yesno_maxind, namespace="yesno_labels")
            else:
                yesno_pred = 'no_yesno'
                yesno_logit = -30.0

            passage_str = instance_metadata['original_passage']
            offsets = instance_metadata['token_offsets']

            # In this version yesno, if not "no_yesno", is regarded as the final answer
            # before the spans are considered.
            if yesno_pred != 'no_yesno':
                best_span_string = yesno_pred
            elif cannot_answer_logit + 0.9 > best_span_logit:
                best_span_string = 'cannot_answer'
            else:
                wordpiece_offsets = self.bert_offsets_to_wordpiece_offsets(
                    bert_offsets[instance_ind][0:len(offsets)])
                # Predictions that fall past the mapped range use its last entry.
                last_index = len(wordpiece_offsets) - 1
                start_index = min(predicted_span[0], last_index)
                end_index = min(predicted_span[1], last_index)
                start_offset = offsets[wordpiece_offsets[start_index]][0]
                end_offset = offsets[wordpiece_offsets[end_index]][1]
                best_span_string = passage_str[start_offset:end_offset]

            output_dict['best_span_str'].append(best_span_string)
            output_dict['cannot_answer_logit'].append(cannot_answer_logit)
            output_dict['best_span_logit'].append(best_span_logit)
            output_dict['yesno'].append(yesno_pred)
            output_dict['yesno_logit'].append(yesno_logit)
            output_dict['qid'].append(instance_metadata['question_id'])

            # In AllenNLP prediction mode we have no gold answers, so only score when they exist.
            if has_gold_spans:
                if yesno_labels_np is not None:
                    yesno_label = self.vocab.get_token_from_index(yesno_labels_np[instance_ind],
                                                                  namespace="yesno_labels")
                else:
                    # Without gold yes/no labels, fall back to the span / cannot-answer gold.
                    yesno_label = 'no_yesno'

                if yesno_label != 'no_yesno':
                    gold_answer_texts = [yesno_label]
                elif instance_metadata['cannot_answer']:
                    gold_answer_texts = ['cannot_answer']
                else:
                    gold_answer_texts = instance_metadata['answer_texts_list']

                f1_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.f1_score, best_span_string, gold_answer_texts)
                EM_score = squad_eval.metric_max_over_ground_truths(
                    squad_eval.exact_match_score, best_span_string, gold_answer_texts)
                self._official_f1(100 * f1_score)
                self._official_EM(100 * EM_score)
                output_dict['EM'].append(100 * EM_score)
                output_dict['f1'].append(100 * f1_score)

        return output_dict
예제 #24
0
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Embed wordpiece ids with the wrapped BERT model, optionally keeping only
        one wordpiece embedding per original token.

        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
            Positions whose id is 0 are excluded from the attention mask.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.

        Returns
        -------
        ``torch.Tensor``
            The embeddings reshaped back to the leading dimensions of ``input_ids``
            when ``offsets`` is ``None``, otherwise to the leading dimensions of
            ``offsets``; the embedding dimension is last in both cases.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        input_mask = (input_ids != 0).long()
        # Flatten the mask once; BERT and the scalar mix both operate on the
        # flattened (batch_size * d1 * ... * dn, sequence_length) layout.
        flat_mask = util.combine_initial_dims(input_mask)

        # input_ids may have extra dimensions, so we reshape down to 2-d
        # before calling the BERT model and then reshape back at the end.
        all_encoder_layers, _ = self.bert_model(
            input_ids=util.combine_initial_dims(input_ids),
            token_type_ids=util.combine_initial_dims(token_type_ids),
            attention_mask=flat_mask)
        if self._scalar_mix is not None:
            # The encoder layers are flattened, so the mask handed to the mix must be
            # flattened too; the un-flattened mask only lines up for 2-d input.
            mix = self._scalar_mix(all_encoder_layers, flat_mask)
        else:
            mix = all_encoder_layers[-1]

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            return util.uncombine_initial_dims(mix, input_ids.size())

        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = util.combine_initial_dims(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        range_vector = util.get_range_vector(
            offsets2d.size(0), device=util.get_device_of(mix)).unsqueeze(1)
        # Row i of range_vector pairs with row i of offsets2d, so each sequence
        # picks its own wordpiece positions.
        selected_embeddings = mix[range_vector, offsets2d]

        return util.uncombine_initial_dims(selected_embeddings,
                                           offsets.size())
예제 #25
0
    def forward(self,
                input_ids: torch.LongTensor,
                offsets: torch.LongTensor = None,
                token_type_ids: torch.LongTensor = None) -> torch.Tensor:
        """
        Compute BERT embeddings for a tensor of wordpiece ids and, when ``offsets``
        is given, select a single wordpiece embedding for each original token.

        Parameters
        ----------
        input_ids : ``torch.LongTensor``
            The (batch_size, ..., max_sequence_length) tensor of wordpiece ids.
            An id of 0 is treated as masked when the attention mask is built.
        offsets : ``torch.LongTensor``, optional
            The BERT embeddings are one per wordpiece. However it's possible/likely
            you might want one per original token. In that case, ``offsets``
            represents the indices of the desired wordpiece for each original token.
            Depending on how your token indexer is configured, this could be the
            position of the last wordpiece for each token, or it could be the position
            of the first wordpiece for each token.

            For example, if you had the sentence "Definitely not", and if the corresponding
            wordpieces were ["Def", "##in", "##ite", "##ly", "not"], then the input_ids
            would be 5 wordpiece ids, and the "last wordpiece" offsets would be [3, 4].
            If offsets are provided, the returned tensor will contain only the wordpiece
            embeddings at those positions, and (in particular) will contain one embedding
            per token. If offsets are not provided, the entire tensor of wordpiece embeddings
            will be returned.
        token_type_ids : ``torch.LongTensor``, optional
            If an input consists of two sentences (as in the BERT paper),
            tokens from the first sentence should have type 0 and tokens from
            the second sentence should have type 1.  If you don't provide this
            (the default BertIndexer doesn't) then it's assumed to be all 0s.

        Returns
        -------
        ``torch.Tensor``
            With ``offsets`` omitted, the wordpiece embeddings restored to the
            leading dimensions of ``input_ids``; otherwise the selected embeddings
            restored to the leading dimensions of ``offsets``.
        """
        # pylint: disable=arguments-differ
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # input_ids may have extra dimensions, so everything handed to BERT is
        # reshaped down to 2-d first and the result is reshaped back at the end.
        input_mask = (input_ids != 0).long()
        mask2d = util.combine_initial_dims(input_mask)
        ids2d = util.combine_initial_dims(input_ids)
        type_ids2d = util.combine_initial_dims(token_type_ids)

        all_encoder_layers, _ = self.bert_model(input_ids=ids2d,
                                                token_type_ids=type_ids2d,
                                                attention_mask=mask2d)

        if self._scalar_mix is None:
            mix = all_encoder_layers[-1]
        else:
            # Use the 2-d mask: it matches the flattened encoder layers, whereas the
            # original-shape mask only agrees with them when there are no extra dims.
            mix = self._scalar_mix(all_encoder_layers, mask2d)

        # At this point, mix is (batch_size * d1 * ... * dn, sequence_length, embedding_dim)

        if offsets is None:
            # Resize to (batch_size, d1, ..., dn, sequence_length, embedding_dim)
            return util.uncombine_initial_dims(mix, input_ids.size())

        # offsets is (batch_size, d1, ..., dn, orig_sequence_length)
        offsets2d = util.combine_initial_dims(offsets)
        # now offsets is (batch_size * d1 * ... * dn, orig_sequence_length)
        row_index = util.get_range_vector(offsets2d.size(0),
                                          device=util.get_device_of(mix)).unsqueeze(1)
        # Pairing each row index with that row's offsets gathers, per sequence,
        # the embeddings at the requested wordpiece positions.
        selected_embeddings = mix[row_index, offsets2d]

        return util.uncombine_initial_dims(selected_embeddings, offsets.size())