def _get_encoder(model_name): if model_name == 'bert' or model_name == 'roberta': sketchy_encoder = BERTEncoder( bert_config=self.bert_config, is_training=is_training, input_ids=split_placeholders['input_ids'], input_mask=split_placeholders['input_mask'], segment_ids=split_placeholders['segment_ids'], scope='bert', **kwargs) elif model_name == 'albert': sketchy_encoder = ALBERTEncoder( albert_config=self.bert_config, is_training=is_training, input_ids=split_placeholders['input_ids'], input_mask=split_placeholders['input_mask'], segment_ids=split_placeholders['segment_ids'], scope='bert', **kwargs) elif model_name == 'electra': sketchy_encoder = BERTEncoder( bert_config=self.bert_config, is_training=is_training, input_ids=split_placeholders['input_ids'], input_mask=split_placeholders['input_mask'], segment_ids=split_placeholders['segment_ids'], scope='electra', **kwargs) return sketchy_encoder
def _forward(self, is_training, split_placeholders, **kwargs):
    """Wire the SemBERT graph: BERT encoder -> SemBERT decoder.

    Returns:
        Tuple of (total_loss, losses, probs, preds) produced by the decoder.
    """
    inputs = split_placeholders

    bert = BERTEncoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_ids=inputs['input_ids'],
        input_mask=inputs['input_mask'],
        segment_ids=inputs['segment_ids'],
        drop_pooler=self._drop_pooler,
        scope='bert',
        **kwargs)
    pooled = bert.get_pooled_output()

    # The decoder fuses the pooled representation with the semantic
    # role features supplied via the placeholders.
    sembert = SemBERTDecoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_tensor=pooled,
        input_mask=inputs['input_mask'],
        sem_features=inputs['sem_features'],
        label_ids=inputs['label_ids'],
        max_seq_length=self.max_seq_length,
        feature_size=len(self.sem_features),
        label_size=self.label_size,
        sample_weight=inputs.get('sample_weight'),
        scope='cls/seq_relationship',
        **kwargs)

    total_loss, losses, probs, preds = sembert.get_forward_outputs()
    return (total_loss, losses, probs, preds)
def _forward(self, is_training, split_placeholders, **kwargs):
    """UDA training forward pass.

    At inference time this defers to the parent implementation. During
    training it appends the augmented views of the *unsupervised* examples
    to the batch, encodes everything in one pass, and lets UDADecoder
    apply the consistency-training objective.
    """
    if not is_training:
        return super()._forward(is_training, split_placeholders, **kwargs)

    def _unsupervised_only(key):
        # Keep rows where the example is NOT supervised.
        # NOTE(review): the mask is a float (1.0 - is_supervised) rather
        # than a bool tensor — presumably accepted by this tf version;
        # confirm against tf.boolean_mask's requirements.
        return tf.boolean_mask(
            split_placeholders[key],
            mask=(1.0 - split_placeholders['is_supervised']),
            axis=0)

    aug_input_ids = _unsupervised_only('aug_input_ids')
    aug_input_mask = _unsupervised_only('aug_input_mask')
    aug_segment_ids = _unsupervised_only('aug_segment_ids')

    # Stack original batch first, augmented unsupervised rows after.
    input_ids = tf.concat(
        [split_placeholders['input_ids'], aug_input_ids], axis=0)
    input_mask = tf.concat(
        [split_placeholders['input_mask'], aug_input_mask], axis=0)
    segment_ids = tf.concat(
        [split_placeholders['segment_ids'], aug_segment_ids], axis=0)

    encoder = BERTEncoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        scope='bert',
        drop_pooler=self._drop_pooler,
        **kwargs)
    encoder_output = encoder.get_pooled_output()

    # Mark which rows are the appended augmented copies: 0 for the
    # original batch, 1 for the augmented tail.
    label_ids = split_placeholders['label_ids']
    aug_batch_size = util.get_shape_list(aug_input_ids)[0]
    is_expanded = tf.concat(
        [tf.zeros_like(label_ids, dtype=tf.float32),
         tf.ones((aug_batch_size), dtype=tf.float32)],
        axis=0)

    decoder = UDADecoder(
        is_training=is_training,
        input_tensor=encoder_output,
        is_supervised=split_placeholders['is_supervised'],
        is_expanded=is_expanded,
        label_ids=label_ids,
        label_size=self.label_size,
        sample_weight=split_placeholders.get('sample_weight'),
        scope='cls/seq_relationship',
        global_step=self._global_step,
        num_train_steps=self.total_steps,
        uda_softmax_temp=self._uda_softmax_temp,
        uda_confidence_thresh=self._uda_confidence_thresh,
        tsa_schedule=self._tsa_schedule,
        **kwargs)

    total_loss, losses, probs, preds = decoder.get_forward_outputs()
    return (total_loss, losses, probs, preds)
def _forward(self, is_training, split_placeholders, **kwargs):
    """ELECTRA machine-reading-comprehension graph.

    Encodes under the 'electra' variable scope, then decodes span labels
    from the full sequence output with MRCDecoder.
    """
    inputs = split_placeholders

    # drop_pooler=True: only the sequence output is consumed below.
    electra = BERTEncoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_ids=inputs['input_ids'],
        input_mask=inputs['input_mask'],
        segment_ids=inputs['segment_ids'],
        scope='electra',
        drop_pooler=True,
        **kwargs)
    sequence_output = electra.get_sequence_output()

    mrc = MRCDecoder(
        is_training=is_training,
        input_tensor=sequence_output,
        label_ids=inputs['label_ids'],
        sample_weight=inputs.get('sample_weight'),
        scope='mrc',
        **kwargs)

    total_loss, losses, probs, preds = mrc.get_forward_outputs()
    return (total_loss, losses, probs, preds)
def _forward(self, is_training, split_placeholders, **kwargs):
    """ELECTRA binary classification graph.

    Encodes under the 'electra' variable scope and classifies from the
    pooled output via BinaryCLSDecoder.
    """
    inputs = split_placeholders

    electra = BERTEncoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_ids=inputs['input_ids'],
        input_mask=inputs['input_mask'],
        segment_ids=inputs['segment_ids'],
        scope='electra',
        drop_pooler=True,
        **kwargs)
    pooled = electra.get_pooled_output()

    cls = BinaryCLSDecoder(
        is_training=is_training,
        input_tensor=pooled,
        label_ids=inputs['label_ids'],
        label_size=self.label_size,
        sample_weight=inputs.get('sample_weight'),
        label_weight=self.label_weight,
        scope='cls/seq_relationship',
        name='cls',
        **kwargs)

    total_loss, losses, probs, preds = cls.get_forward_outputs()
    return (total_loss, losses, probs, preds)
def _forward(self, is_training, split_placeholders, **kwargs):
    """Masked-LM pretraining graph: BERT encoder -> BERT LM decoder."""
    inputs = split_placeholders

    bert = BERTEncoder(
        bert_config=self.bert_config,
        is_training=is_training,
        input_ids=inputs['input_ids'],
        input_mask=inputs['input_mask'],
        segment_ids=inputs['segment_ids'],
        scope='bert',
        drop_pooler=self._drop_pooler,
        **kwargs)

    # The LM decoder takes the whole encoder (not just one output) so it
    # can gather hidden states at the masked positions.
    lm = BERTDecoder(
        bert_config=self.bert_config,
        is_training=is_training,
        encoder=bert,
        masked_lm_positions=inputs['masked_lm_positions'],
        masked_lm_ids=inputs['masked_lm_ids'],
        masked_lm_weights=inputs['masked_lm_weights'],
        sample_weight=inputs.get('sample_weight'),
        scope_lm='cls/predictions',
        **kwargs)

    total_loss, losses, probs, preds = lm.get_forward_outputs()
    return (total_loss, losses, probs, preds)