示例#1
0
 def _get_encoder(model_name):
     if model_name == 'bert' or model_name == 'roberta':
         sketchy_encoder = BERTEncoder(
             bert_config=self.bert_config,
             is_training=is_training,
             input_ids=split_placeholders['input_ids'],
             input_mask=split_placeholders['input_mask'],
             segment_ids=split_placeholders['segment_ids'],
             scope='bert',
             **kwargs)
     elif model_name == 'albert':
         sketchy_encoder = ALBERTEncoder(
             albert_config=self.bert_config,
             is_training=is_training,
             input_ids=split_placeholders['input_ids'],
             input_mask=split_placeholders['input_mask'],
             segment_ids=split_placeholders['segment_ids'],
             scope='bert',
             **kwargs)
     elif model_name == 'electra':
         sketchy_encoder = BERTEncoder(
             bert_config=self.bert_config,
             is_training=is_training,
             input_ids=split_placeholders['input_ids'],
             input_mask=split_placeholders['input_mask'],
             segment_ids=split_placeholders['segment_ids'],
             scope='electra',
             **kwargs)
     return sketchy_encoder
示例#2
0
    def _forward(self, is_training, split_placeholders, **kwargs):
        """Build the SemBERT forward graph.

        A ``BERTEncoder`` (scope ``'bert'``) encodes the token inputs. Its
        pooled output, together with the ``sem_features`` tensor, is fed to a
        ``SemBERTDecoder`` (scope ``'cls/seq_relationship'``).

        Args:
            is_training: Training flag forwarded to encoder and decoder.
            split_placeholders: Mapping of input tensors. Reads
                ``input_ids``, ``input_mask``, ``segment_ids``,
                ``sem_features``, ``label_ids`` and, optionally,
                ``sample_weight``.
            **kwargs: Forwarded unchanged to both encoder and decoder.

        Returns:
            Tuple ``(total_loss, losses, probs, preds)`` unpacked from the
            decoder's ``get_forward_outputs()``.
        """
        feeds = split_placeholders

        bert = BERTEncoder(
            bert_config=self.bert_config,
            is_training=is_training,
            input_ids=feeds['input_ids'],
            input_mask=feeds['input_mask'],
            segment_ids=feeds['segment_ids'],
            drop_pooler=self._drop_pooler,
            scope='bert',
            **kwargs)
        pooled = bert.get_pooled_output()

        sem_decoder = SemBERTDecoder(
            bert_config=self.bert_config,
            is_training=is_training,
            input_tensor=pooled,
            input_mask=feeds['input_mask'],
            sem_features=feeds['sem_features'],
            label_ids=feeds['label_ids'],
            max_seq_length=self.max_seq_length,
            # One feature slot per configured semantic feature name.
            feature_size=len(self.sem_features),
            label_size=self.label_size,
            sample_weight=feeds.get('sample_weight'),
            scope='cls/seq_relationship',
            **kwargs)

        outputs = sem_decoder.get_forward_outputs()
        total_loss, losses, probs, preds = outputs
        return total_loss, losses, probs, preds
示例#3
0
文件: uda.py 项目: zhongyunuestc/unif
    def _forward(self, is_training, split_placeholders, **kwargs):
        """Build the UDA forward graph.

        When ``is_training`` is falsy this defers to the parent class's
        ``_forward``. During training, the augmented inputs (``aug_input_ids``,
        ``aug_input_mask``, ``aug_segment_ids``) of the rows where
        ``1.0 - is_supervised`` is non-zero are appended after the original
        batch. The combined batch is encoded by a ``BERTEncoder`` (scope
        ``'bert'``), and its pooled output is fed to a ``UDADecoder``.

        Args:
            is_training: Training flag; selects between the two paths above.
            split_placeholders: Mapping of input tensors. Reads
                ``input_ids``, ``input_mask``, ``segment_ids``, their ``aug_``
                counterparts, ``is_supervised``, ``label_ids`` and, optionally,
                ``sample_weight``. ``is_supervised`` is subtracted from
                ``1.0``, so it is presumably a float 0/1 per-row tensor --
                TODO confirm.
            **kwargs: Forwarded to the encoder and the decoder.

        Returns:
            Tuple ``(total_loss, losses, probs, preds)`` unpacked from the
            decoder's ``get_forward_outputs()``.
        """
        if not is_training:
            return super()._forward(is_training, split_placeholders, **kwargs)

        # Built once and shared by all three masks. tf.boolean_mask documents
        # a boolean mask, so cast explicitly: non-zero entries (rows that are
        # not supervised) become True and are kept, exactly as before.
        is_unsupervised = tf.cast(
            1.0 - split_placeholders['is_supervised'], tf.bool)

        keys = ('input_ids', 'input_mask', 'segment_ids')
        # Augmented copies of the selected rows only.
        aug = {
            key: tf.boolean_mask(
                split_placeholders['aug_' + key], mask=is_unsupervised, axis=0)
            for key in keys}
        # Original batch first, selected augmented rows appended after it.
        merged = {
            key: tf.concat([split_placeholders[key], aug[key]], axis=0)
            for key in keys}

        encoder = BERTEncoder(bert_config=self.bert_config,
                              is_training=is_training,
                              input_ids=merged['input_ids'],
                              input_mask=merged['input_mask'],
                              segment_ids=merged['segment_ids'],
                              scope='bert',
                              drop_pooler=self._drop_pooler,
                              **kwargs)
        encoder_output = encoder.get_pooled_output()

        label_ids = split_placeholders['label_ids']
        # is_expanded: 0.0 for each original row, 1.0 for each appended
        # augmented row (assumes label_ids has one entry per original row).
        num_aug = util.get_shape_list(aug['input_ids'])[0]
        is_expanded = tf.concat(
            [tf.zeros_like(label_ids, dtype=tf.float32),
             tf.ones([num_aug], dtype=tf.float32)],
            axis=0)

        decoder = UDADecoder(
            is_training=is_training,
            input_tensor=encoder_output,
            is_supervised=split_placeholders['is_supervised'],
            is_expanded=is_expanded,
            label_ids=label_ids,
            label_size=self.label_size,
            sample_weight=split_placeholders.get('sample_weight'),
            scope='cls/seq_relationship',
            global_step=self._global_step,
            num_train_steps=self.total_steps,
            uda_softmax_temp=self._uda_softmax_temp,
            uda_confidence_thresh=self._uda_confidence_thresh,
            tsa_schedule=self._tsa_schedule,
            **kwargs)
        (total_loss, losses, probs, preds) = decoder.get_forward_outputs()
        return (total_loss, losses, probs, preds)
示例#4
0
    def _forward(self, is_training, split_placeholders, **kwargs):
        """Build the MRC forward graph on an ELECTRA-scoped encoder.

        A ``BERTEncoder`` built under scope ``'electra'`` (with
        ``drop_pooler=True``) produces a sequence output. That output is fed
        to an ``MRCDecoder`` (scope ``'mrc'``).

        Args:
            is_training: Training flag forwarded to encoder and decoder.
            split_placeholders: Mapping of input tensors. Reads
                ``input_ids``, ``input_mask``, ``segment_ids``, ``label_ids``
                and, optionally, ``sample_weight``.
            **kwargs: Forwarded unchanged to both encoder and decoder.

        Returns:
            Tuple ``(total_loss, losses, probs, preds)`` unpacked from the
            decoder's ``get_forward_outputs()``.
        """
        config = self.bert_config
        token_inputs = {
            name: split_placeholders[name]
            for name in ('input_ids', 'input_mask', 'segment_ids')}

        electra = BERTEncoder(
            bert_config=config,
            is_training=is_training,
            scope='electra',
            drop_pooler=True,
            **token_inputs,
            **kwargs)
        # MRC consumes the per-token output, not a pooled vector.
        sequence_output = electra.get_sequence_output()

        mrc = MRCDecoder(
            is_training=is_training,
            input_tensor=sequence_output,
            label_ids=split_placeholders['label_ids'],
            sample_weight=split_placeholders.get('sample_weight'),
            scope='mrc',
            **kwargs)

        total_loss, losses, probs, preds = mrc.get_forward_outputs()
        return total_loss, losses, probs, preds
示例#5
0
    def _forward(self, is_training, split_placeholders, **kwargs):
        """Build the binary-classification forward graph on an ELECTRA scope.

        A ``BERTEncoder`` is built under scope ``'electra'`` with
        ``drop_pooler=True``. Its ``get_pooled_output()`` is fed to a
        ``BinaryCLSDecoder`` (scope ``'cls/seq_relationship'``, name
        ``'cls'``).

        NOTE(review): what ``get_pooled_output()`` returns when the pooler is
        dropped is decided inside ``BERTEncoder`` and is not visible here.

        Args:
            is_training: Training flag forwarded to encoder and decoder.
            split_placeholders: Mapping of input tensors. Reads
                ``input_ids``, ``input_mask``, ``segment_ids``, ``label_ids``
                and, optionally, ``sample_weight``.
            **kwargs: Forwarded unchanged to both encoder and decoder.

        Returns:
            Tuple ``(total_loss, losses, probs, preds)`` unpacked from the
            decoder's ``get_forward_outputs()``.
        """
        ph = split_placeholders

        backbone = BERTEncoder(bert_config=self.bert_config,
                               is_training=is_training,
                               input_ids=ph['input_ids'],
                               input_mask=ph['input_mask'],
                               segment_ids=ph['segment_ids'],
                               scope='electra',
                               drop_pooler=True,
                               **kwargs)
        features = backbone.get_pooled_output()

        classifier = BinaryCLSDecoder(
            is_training=is_training,
            input_tensor=features,
            label_ids=ph['label_ids'],
            label_size=self.label_size,
            sample_weight=ph.get('sample_weight'),
            label_weight=self.label_weight,
            scope='cls/seq_relationship',
            name='cls',
            **kwargs)

        results = classifier.get_forward_outputs()
        total_loss, losses, probs, preds = results
        return (total_loss, losses, probs, preds)
示例#6
0
    def _forward(self, is_training, split_placeholders, **kwargs):
        """Build the masked-LM forward graph.

        A ``BERTEncoder`` (scope ``'bert'``) is constructed and handed, as an
        object, to a ``BERTDecoder``. The decoder reads the masked-LM
        placeholders and uses ``scope_lm='cls/predictions'``.

        Args:
            is_training: Training flag forwarded to encoder and decoder.
            split_placeholders: Mapping of input tensors. Reads
                ``input_ids``, ``input_mask``, ``segment_ids``,
                ``masked_lm_positions``, ``masked_lm_ids``,
                ``masked_lm_weights`` and, optionally, ``sample_weight``.
            **kwargs: Forwarded unchanged to both encoder and decoder.

        Returns:
            Tuple ``(total_loss, losses, probs, preds)`` unpacked from the
            decoder's ``get_forward_outputs()``.
        """
        placeholders = split_placeholders

        bert_encoder = BERTEncoder(
            bert_config=self.bert_config,
            is_training=is_training,
            input_ids=placeholders['input_ids'],
            input_mask=placeholders['input_mask'],
            segment_ids=placeholders['segment_ids'],
            scope='bert',
            drop_pooler=self._drop_pooler,
            **kwargs)

        masked_lm_inputs = dict(
            masked_lm_positions=placeholders['masked_lm_positions'],
            masked_lm_ids=placeholders['masked_lm_ids'],
            masked_lm_weights=placeholders['masked_lm_weights'],
        )
        # The decoder receives the encoder object itself rather than one of
        # its output tensors.
        lm_decoder = BERTDecoder(
            bert_config=self.bert_config,
            is_training=is_training,
            encoder=bert_encoder,
            sample_weight=placeholders.get('sample_weight'),
            scope_lm='cls/predictions',
            **masked_lm_inputs,
            **kwargs)

        total_loss, losses, probs, preds = lm_decoder.get_forward_outputs()
        return total_loss, losses, probs, preds