def _init_classifier(self, hidden_size):
    if "pretraining" in self.config.training_head_type:
        self.classifier = BertPreTrainingHeads(self.bert_config)
    if "vqa" in self.config.training_head_type:
        self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
        self.answer_space_size = 3129
        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.bert_config),
            nn.Linear(self.bert_config.hidden_size, self.answer_space_size),
        )
        # self.classifier = nn.Linear(self.bert_config.hidden_size,
        #                             self.answer_space_size)
    elif "vizwiz" in self.config.training_head_type:
        self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
        self.answer_space_size = 7371
        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.bert_config),
            nn.Linear(self.bert_config.hidden_size, self.answer_space_size),
        )
        # self.classifier = nn.Linear(self.bert_config.hidden_size,
        #                             self.answer_space_size)
    elif self.config.training_head_type == "visual_entailment":
        self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.bert_config),
            nn.Linear(self.bert_config.hidden_size, 3),
        )
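
# Illustrative usage sketch (not part of the snippet above): a classifier built by
# _init_classifier is typically applied to the pooled [CLS] representation of the
# encoder output; `pooled_output` and `reshaped_logits` are assumed names.
#
#     pooled_output = self.dropout(pooled_output)      # [batch, hidden_size]
#     logits = self.classifier(pooled_output)          # [batch, answer_space_size]
#     reshaped_logits = logits.contiguous().view(-1, self.answer_space_size)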
def __init__(self, config: BertConfig):
    super().__init__()
    self.transform1 = BertPredictionHeadTransform(config)
    self.transform2 = BertPredictionHeadTransform(config)
    self.decoder = nn.Linear(config.hidden_size, config.hidden_size)
def __init__(self, config):
    super(BertImagePredictionHead, self).__init__()
    self.transform = BertPredictionHeadTransform(config)
    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    self.decoder = nn.Linear(config.hidden_size, config.target_size)
def build(self): self.text_processor = registry.get(self._datasets[0] + "_text_processor") self.vocab = self.text_processor.vocab self.word_embedding = self.vocab.get_embedding( torch.nn.Embedding, freeze=False, embedding_dim=self.config.text_embedding.embedding_dim) self.segment_embeddings = nn.Embedding(self.config.num_segment_type, self.config.hidden_size) self.cls_project = nn.Linear(self.config.text_embedding.embedding_dim, self.config.hidden_size) self.lstm = nn.LSTM(**self.config.lstm) self.lstm_proj = nn.Linear(self.config.hidden_size * 2, self.config.hidden_size) self.img_encoder = ImageClevrEncoder(self.config) self.img_pos_emb = nn.Linear(2, self.config.hidden_size) self.LayerNorm = nn.LayerNorm(self.config.hidden_size, eps=self.config.layer_norm_eps) self.dropout = nn.Dropout(self.config.hidden_dropout_prob) self.bert_config = BertConfig.from_dict( OmegaConf.to_container(self.config, resolve=True)) self.transformer = BertEncoder(self.bert_config) self.pooler = BertPooler(self.bert_config) self.classifier = nn.Sequential( BertPredictionHeadTransform(self.config), nn.Linear(self.config.hidden_size, self.config.num_labels), ) self.head_mask = [None for _ in range(self.config.num_hidden_layers)]
def __init__(self, config: BertConfig, bert_model_embedding_weights, position_embedding_size=200): super().__init__() self.position_embeddings = nn.Embedding(config.max_position_embeddings, position_embedding_size) self.pos_emb_proj = nn.Linear(position_embedding_size, config.hidden_size) self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. self.decoder = nn.Linear(bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0), bias=False) self.decoder.weight = bert_model_embedding_weights self.bias = nn.Parameter( torch.zeros(bert_model_embedding_weights.size(0))) # position_ids (1, len position emb) is contiguous in memory and exported when serialized self.register_buffer( "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def __init__(self, config, extra_config): super().__init__() self.config = config self.output_attentions = self.config.output_attentions self.output_hidden_states = self.config.output_hidden_states self.pooler_strategy = self.config.get("pooler_strategy", "default") # Graph input params self.feed_graph_to_vb = extra_config["feed_graph_to_vb"] self.graph_node_hid_dim = extra_config["node_hid_dim"] self.graph_feed_mode = extra_config["feed_mode"] self.graph_topk = extra_config["topk_ans_feed"] # If doing graph, make a graph embedding layer if self.feed_graph_to_vb: self.graph_embedding = nn.Sequential( nn.Linear(self.graph_node_hid_dim, config.hidden_size), nn.LayerNorm(config.hidden_size, eps=1e-12), nn.Dropout(config.hidden_dropout_prob), # hidden_dropout_prb ) # If bert_model_name is not specified, you will need to specify # all of the required parameters for BERTConfig and a pretrained # model won't be loaded self.bert_model_name = self.config.get("bert_model_name", None) self.bert_config = BertConfig.from_dict( OmegaConf.to_container(self.config, resolve=True) ) if self.bert_model_name is None or self.bert_model_name == "nopretrain": self.bert = VisualBERTBase( self.bert_config, visual_embedding_dim=self.config.visual_embedding_dim, embedding_strategy=self.config.embedding_strategy, bypass_transformer=self.config.bypass_transformer, output_attentions=self.config.output_attentions, output_hidden_states=self.config.output_hidden_states, ) else: self.bert = VisualBERTBase.from_pretrained( self.config.bert_model_name, config=self.bert_config, cache_dir=os.path.join( get_mmf_cache_dir(), "distributed_{}".format(-1) ), visual_embedding_dim=self.config.visual_embedding_dim, embedding_strategy=self.config.embedding_strategy, bypass_transformer=self.config.bypass_transformer, output_attentions=self.config.output_attentions, output_hidden_states=self.config.output_hidden_states, ) self.training_head_type = self.config.training_head_type self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob) if self.config.training_head_type == "nlvr2": self.bert.config.hidden_size *= 2 self.classifier = nn.Sequential(BertPredictionHeadTransform(self.bert.config)) self.init_weights()
def __init__(self, config, decoder_weight):
    super(BertLMPredictionHead, self).__init__()
    self.transform = BertPredictionHeadTransform(config)
    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    self.decoder_weight = decoder_weight
    self.bias = nn.Parameter(torch.zeros(config.vocab_size))
def __init__(self, config, *args, **kwargs):
    super().__init__()
    self.config = config
    self.bert = MMBTBase(config, *args, **kwargs)
    self.encoder_config = self.bert.encoder_config
    self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
    self.classifier = nn.Sequential(
        BertPredictionHeadTransform(self.encoder_config),
        nn.Linear(self.encoder_config.hidden_size, self.config.num_labels),
    )
def __init__(self, config):
    super(BertForImageCaptioning, self).__init__(config)
    self.config = config
    self.bert = BertImgModel(config)
    self.transform = BertPredictionHeadTransform(config)
    bert_embedding_weight = self.bert.embeddings.word_embeddings.weight
    self.decoder = nn.Linear(bert_embedding_weight.size(1),
                             bert_embedding_weight.size(0),
                             bias=False)
    self.loss = nn.CrossEntropyLoss(reduction='mean')
    self.drop_worst_ratio = 0.2
def __init__(self, config=None, *args, **kwargs):
    super().__init__(*args, **kwargs)
    if config is None:
        from transformers.configuration_bert import BertConfig
        config = BertConfig.from_pretrained('bert-base-uncased')
    assert config.hidden_size == self.in_dim
    from transformers.modeling_bert import BertPredictionHeadTransform
    self.module = nn.Sequential(
        nn.Dropout(config.hidden_dropout_prob),
        BertPredictionHeadTransform(config),
        nn.Linear(self.in_dim, self.out_dim),
    )
def __init__(self, config):
    super(NodeConstructOutputLayer, self).__init__()
    self.transform = BertPredictionHeadTransform(config)
    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    self.decoder = nn.Linear(config.hidden_size, config.x_size, bias=False)
    self.bias = nn.Parameter(torch.zeros(config.x_size))
    # Need a link between the two variables so that the bias is correctly
    # resized with `resize_token_embeddings`
    self.decoder.bias = self.bias
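
# Note (illustrative, assuming the class above): after `self.decoder.bias = self.bias`,
# both names refer to the same Parameter, so resizing or reloading the decoder keeps
# the externally visible bias in sync, e.g.:
#
#     head = NodeConstructOutputLayer(config)
#     assert head.decoder.bias is head.bias   # one tensor, two names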
def __init__(self, config: Config, *args, **kwargs):
    super().__init__(config, *args, **kwargs)
    # Head modules
    self.pooler = BertPooler(self.config)
    self.classifier = nn.Sequential(
        nn.Dropout(self.config.hidden_dropout_prob),
        BertPredictionHeadTransform(self.config),
        nn.Linear(self.config.hidden_size, self.config.num_labels),
    )
    self.num_labels = self.config.num_labels
    self.hidden_size = self.config.hidden_size
def __init__(self, config):
    super().__init__()
    self.transform = BertPredictionHeadTransform(config)
    self.visual_losses = config.visual_losses
    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    self.decoder_dict = nn.ModuleDict({
        key: nn.Linear(config.hidden_size, config.visual_loss_config[key][0])
        for key in self.visual_losses
    })
def build_heads(self):
    """Initialize the classifier head. It takes the output of the transformer
    encoder and passes it through a pooler (we use the pooler from the BERT
    model), then dropout, BertPredictionHeadTransform (which is a linear layer
    followed by activation and layer norm), and lastly a linear layer projecting
    the hidden output to the classification labels.
    """
    self.classifier = nn.Sequential(
        BertPooler(self.transformer_config),
        nn.Dropout(self.transformer_config.hidden_dropout_prob),
        BertPredictionHeadTransform(self.transformer_config),
        nn.Linear(self.transformer_config.hidden_size, self.config.num_labels),
    )
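
# For reference, a minimal sketch of what BertPredictionHeadTransform computes
# (paraphrasing the Hugging Face implementation; the real class reads the hidden
# activation and the LayerNorm epsilon from its BertConfig rather than hard-coding them):
from torch import nn


class PredictionHeadTransformSketch(nn.Module):
    """hidden_states -> Linear -> activation -> LayerNorm, all at hidden_size width."""

    def __init__(self, hidden_size, layer_norm_eps=1e-12):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.act_fn = nn.GELU()  # BERT's default hidden_act is "gelu"
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(self, hidden_states):
        return self.LayerNorm(self.act_fn(self.dense(hidden_states)))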
def __init__(self, config, bert_model_embedding_weights):
    super().__init__()
    self.transform = BertPredictionHeadTransform(config)
    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    self.decoder = nn.Linear(
        bert_model_embedding_weights.size(1),
        bert_model_embedding_weights.size(0),
        bias=False,
    )
    self.decoder.weight = bert_model_embedding_weights
    self.bias = nn.Parameter(
        torch.zeros(bert_model_embedding_weights.size(0)))
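
# The snippet above only shows __init__; a forward pass for this kind of weight-tied
# LM head usually looks like the sketch below (assumed, not copied from the source):
def forward(self, hidden_states):
    hidden_states = self.transform(hidden_states)  # Linear + activation + LayerNorm
    # Project onto the tied embedding matrix and add the output-only bias:
    # [batch, seq, hidden] -> [batch, seq, vocab_size]
    return self.decoder(hidden_states) + self.bias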
def __init__(self, config, *args, **kwargs): super().__init__() self.config = config self.bert = MMBTBase(config, *args, **kwargs) self.encoder_config = self.bert.encoder_config self.num_labels = self.config.num_labels self.output_hidden_states = self.encoder_config.output_hidden_states self.output_attentions = self.encoder_config.output_attentions self.fused_feature_only = self.config.fused_feature_only self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob) self.classifier = nn.Sequential( BertPredictionHeadTransform(self.encoder_config), nn.Linear(self.encoder_config.hidden_size, self.config.num_labels), )
def __init__(self, config): super().__init__() self.config = config self.output_attentions = self.config.output_attentions self.output_hidden_states = self.config.output_hidden_states self.pooler_strategy = self.config.get("pooler_strategy", "default") # If bert_model_name is not specified, you will need to specify # all of the required parameters for BERTConfig and a pretrained # model won't be loaded self.bert_model_name = getattr(self.config, "bert_model_name", None) self.bert_config = BertConfig.from_dict( OmegaConf.to_container(self.config, resolve=True) ) if self.bert_model_name is None: self.bert = VisualBERTBase( self.bert_config, visual_embedding_dim=self.config.visual_embedding_dim, embedding_strategy=self.config.embedding_strategy, bypass_transformer=self.config.bypass_transformer, output_attentions=self.config.output_attentions, output_hidden_states=self.config.output_hidden_states, ) else: self.bert = VisualBERTBase.from_pretrained( self.config.bert_model_name, config=self.bert_config, cache_dir=os.path.join( get_multimodelity_cache_dir(), "distributed_{}".format(-1) ), visual_embedding_dim=self.config.visual_embedding_dim, embedding_strategy=self.config.embedding_strategy, bypass_transformer=self.config.bypass_transformer, output_attentions=self.config.output_attentions, output_hidden_states=self.config.output_hidden_states, ) self.training_head_type = self.config.training_head_type self.num_labels = self.config.num_labels self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob) if self.config.training_head_type == "nlvr2": self.bert.config.hidden_size *= 2 self.classifier = nn.Sequential( BertPredictionHeadTransform(self.bert.config), nn.Linear(self.bert.config.hidden_size, self.config.num_labels), ) self.init_weights()
def __init__(self, **kwargs): super().__init__() self.config = kwargs self.output_attentions = self.config['output_attentions'] self.output_hidden_states = self.config['output_hidden_states'] self.pooler_strategy = self.config.get('pooler_strategy', 'default') # If bert_model_name is not specified, you will need to specify # all of the required parameters for BERTConfig and a pretrained # model won't be loaded self.bert_model_name = self.config['bert_model_name'] self.bert_config = BertConfig.from_dict(self.config) if self.bert_model_name is None: self.bert = VisualBERTBase( self.bert_config, visual_embedding_dim=self.config['visual_embedding_dim'], embedding_strategy=self.config['embedding_strategy'], bypass_transformer=self.config['bypass_transformer'], output_attentions=self.config['output_attentions'], output_hidden_states=self.config['output_hidden_states'], ) else: from imix.utils.config import ToExpanduser cache_dir = os.path.join('~/.cache/torch', 'transformers') cache_dir = ToExpanduser.modify_path(cache_dir) self.bert = VisualBERTBase.from_pretrained( self.config['bert_model_name'], config=self.bert_config, cache_dir=cache_dir, visual_embedding_dim=self.config['visual_embedding_dim'], embedding_strategy=self.config['embedding_strategy'], bypass_transformer=self.config['bypass_transformer'], output_attentions=self.config['output_attentions'], output_hidden_states=self.config['output_hidden_states'], ) self.training_head_type = self.config['training_head_type'] self.num_labels = self.config['num_labels'] self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob) if self.config['training_head_type'] == 'nlvr2': self.bert.config.hidden_size *= 2 self.classifier = nn.Sequential( BertPredictionHeadTransform(self.bert.config), nn.Linear(self.bert.config.hidden_size, self.config['num_labels']), ) self.init_weights()
def build_heads(self): """Initialize the classifier head. It takes the output of the transformer encoder and passes it through a pooler (we use the pooler from BERT model), then dropout, BertPredictionHeadTransform (which is a linear layer, followed by activation and layer norm) and lastly a linear layer projecting the hidden output to classification labels. """ transformer_config = self.backend.get_config() if self.config.training_head_type == "classification": self.pooler = BertPooler(transformer_config) self.classifier = nn.Sequential( nn.Dropout(transformer_config.hidden_dropout_prob), BertPredictionHeadTransform(transformer_config), nn.Linear(transformer_config.hidden_size, self.config.num_labels), ) elif self.config.training_head_type == "pretraining": self.cls = BertOnlyMLMHead(transformer_config) self.vocab_size = transformer_config.vocab_size
def __init__(self, config): super().__init__() self.config = config # self.output_attentions = self.config.output_attentions self.output_hidden_states = self.config.output_hidden_states self.bert = VisualBERTBase( self.config, visual_embedding_dim=self.config.visual_embedding_dim, ) self.num_labels = self.config.num_labels self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob) # if self.config.training_head_type == "nlvr2": # self.bert.config.hidden_size *= 2 self.classifier = nn.Sequential( BertPredictionHeadTransform(self.bert.config), nn.Linear(self.bert.config.hidden_size, self.config.num_labels), ) self.init_weights()
def __init__(self, config, bert_model_embedding_weights, hidden_size=768): super(BertLMPredictionHead, self).__init__() config.hidden_size = hidden_size self.transform = BertPredictionHeadTransform(config) # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. if hidden_size == 768: self.decoder = nn.Linear( bert_model_embedding_weights.size(1), bert_model_embedding_weights.size(0), bias=False, ) # print(self.decoder.weight.shape) # 30522, 768 self.decoder.weight = bert_model_embedding_weights else: self.decoder = nn.Linear( hidden_size, bert_model_embedding_weights.size(0), bias=False, ) self.bias = nn.Parameter( torch.zeros(bert_model_embedding_weights.size(0)))
def __init__(self, config, num_labels):
    super(PredictionHead, self).__init__()
    self.transform = BertPredictionHeadTransform(config)
    self.decoder = nn.Linear(config.hidden_size, num_labels)
    self.bias = nn.Parameter(torch.zeros(num_labels))
def __init__(self, config):
    super().__init__()
    self.transform = BertPredictionHeadTransform(config)
    self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
class DialogBERT(nn.Module): '''Hierarchical BERT for dialog v5 with two features: - Masked context utterances prediction with direct MSE matching of their vectors - Energy-based Utterance order prediction: A novel approach to shuffle the context and predict the original order with distributed order prediction''' # TODO: 1. Enhance sorting net # 2. Better data loader for permutation ((avoid returning perm_id and use max(pos_ids) instead, def __init__(self, args, base_model_name='bert-base-uncased'): super(DialogBERT, self).__init__() if args.language == 'chinese': base_model_name = 'bert-base-chinese' self.tokenizer = BertTokenizer.from_pretrained(base_model_name, cache_dir='./cache/') if args.model_size == 'tiny': self.encoder_config = BertConfig(vocab_size=30522, hidden_size=256, num_hidden_layers=6, num_attention_heads=2, intermediate_size=1024) self.utt_encoder = BertForPreTraining(self.encoder_config) elif args.model_size == 'small': self.encoder_config = BertConfig(vocab_size=30522, hidden_size=512, num_hidden_layers=8, num_attention_heads=4, intermediate_size=2048) self.utt_encoder = BertForPreTraining(self.encoder_config) else: self.encoder_config = BertConfig.from_pretrained( base_model_name, cache_dir='./cache/') self.utt_encoder = BertForPreTraining.from_pretrained( base_model_name, config=self.encoder_config, cache_dir='./cache/') self.context_encoder = BertModel( self.encoder_config) # context encoder: encode context to vector self.mlm_mode = 'mse' # 'mdn', 'mse' if self.mlm_mode == 'mdn': self.context_mlm_trans = MixtureDensityNetwork( self.encoder_config.hidden_size, self.encoder_config.hidden_size, 3) else: self.context_mlm_trans = BertPredictionHeadTransform( self.encoder_config ) # transform context hidden states back to utterance encodings self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob) self.context_order_trans = SelfSorting(self.encoder_config.hidden_size) # self.context_order_trans = MLP(self.encoder_config.hidden_size, '200-200-200', 1) self.decoder_config = deepcopy(self.encoder_config) self.decoder_config.is_decoder = True self.decoder_config.add_cross_attention = True self.decoder = BertLMHeadModel(self.decoder_config) def init_weights(self, m): # Initialize Linear Weight for GAN if isinstance(m, nn.Linear): m.weight.data.uniform_(-0.08, 0.08) #nn.init.xavier_normal_(m.weight) nn.init.constant_(m.bias, 0.) 
@classmethod
def from_pretrained(self, model_dir):
    self.encoder_config = BertConfig.from_pretrained(model_dir)
    self.tokenizer = BertTokenizer.from_pretrained(
        path.join(model_dir, 'tokenizer'), do_lower_case=args.do_lower_case)
    self.utt_encoder = BertForPreTraining.from_pretrained(
        path.join(model_dir, 'utt_encoder'))
    self.context_encoder = BertForSequenceClassification.from_pretrained(
        path.join(model_dir, 'context_encoder'))
    self.context_mlm_trans = BertPredictionHeadTransform(self.encoder_config)
    self.context_mlm_trans.load_state_dict(
        torch.load(path.join(model_dir, 'context_mlm_trans.pkl')))
    self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
    self.context_order_trans.load_state_dict(
        torch.load(path.join(model_dir, 'context_order_trans.pkl')))
    self.decoder_config = BertConfig.from_pretrained(model_dir)
    self.decoder = BertLMHeadModel.from_pretrained(
        path.join(model_dir, 'decoder'))

def save_pretrained(self, output_dir):
    def save_module(model, save_path):
        torch.save(model.state_dict(), save_path)

    def make_list_dirs(dir_list):
        for dir_ in dir_list:
            os.makedirs(dir_, exist_ok=True)

    make_list_dirs([
        path.join(output_dir, name)
        for name in ['tokenizer', 'utt_encoder', 'context_encoder', 'decoder']
    ])
    model_to_save = self.module if hasattr(self, 'module') else self
    model_to_save.encoder_config.save_pretrained(output_dir)  # Save configuration file
    model_to_save.tokenizer.save_pretrained(path.join(output_dir, 'tokenizer'))
    model_to_save.utt_encoder.save_pretrained(path.join(output_dir, 'utt_encoder'))
    model_to_save.context_encoder.save_pretrained(
        path.join(output_dir, 'context_encoder'))
    save_module(model_to_save.context_mlm_trans,
                path.join(output_dir, 'context_mlm_trans.pkl'))
    save_module(model_to_save.context_order_trans,
                path.join(output_dir, 'context_order_trans.pkl'))
    model_to_save.decoder_config.save_pretrained(output_dir)  # Save configuration file
    model_to_save.decoder.save_pretrained(path.join(output_dir, 'decoder'))

def utt_encoding(self, context, utts_attn_mask):
    # context: [batch_size x diag_len x max_utt_len]
    batch_size, max_ctx_len, max_utt_len = context.size()
    utts = context.view(-1, max_utt_len)  # [(batch_size*diag_len) x max_utt_len]
    utts_attn_mask = utts_attn_mask.view(-1, max_utt_len)
    _, utts_encodings, *_ = self.utt_encoder.bert(utts, utts_attn_mask)
    utts_encodings = utts_encodings.view(batch_size, max_ctx_len, -1)
    return utts_encodings

def context_encoding(self, context, utts_attn_mask, ctx_attn_mask):
    # with torch.no_grad():
    utt_encodings = self.utt_encoding(context, utts_attn_mask)
    context_hiddens, pooled_output, *_ = self.context_encoder(
        None, ctx_attn_mask, None, None, None, utt_encodings)
    # context_hiddens: [batch_size x ctx_len x dim]; pooled_output: [batch_size x dim]
    return context_hiddens, pooled_output

def train_dialog_flow(self, context, context_utts_attn_mask, context_attn_mask,
                      context_lm_targets, context_position_perm_id,
                      context_position_ids, response):
    """ only train the dialog flow model """
    self.context_encoder.train()  # set the module in training mode.
    self.context_mlm_trans.train()

    context_hiddens, context_encoding = self.context_encoding(
        context, context_utts_attn_mask, context_attn_mask)
    lm_pred_encodings = self.context_mlm_trans(self.dropout(context_hiddens))

    context_lm_targets[context_lm_targets == -100] = 0
    ctx_lm_mask = context_lm_targets.sum(2)
    if (ctx_lm_mask > 0).sum() == 0:
        ctx_lm_mask[0, 0] = 1
    lm_pred_encodings = lm_pred_encodings[ctx_lm_mask > 0]
    context_lm_targets = context_lm_targets[ctx_lm_mask > 0]
    context_lm_targets_attn_mask = context_utts_attn_mask[ctx_lm_mask > 0]

    with torch.no_grad():
        _, lm_tgt_encodings, *_ = self.utt_encoder.bert(
            context_lm_targets, context_lm_targets_attn_mask)

    loss_ctx_mlm = MSELoss()(lm_pred_encodings,
                             lm_tgt_encodings)  # [num_selected_utts x dim]

    # context order prediction
    if isinstance(self.context_order_trans, SelfSorting):
        sorting_scores = self.context_order_trans(context_hiddens,
                                                  context_attn_mask)
    else:
        sorting_scores = self.context_order_trans(context_hiddens)
    sorting_pad_mask = context_attn_mask == 0
    sorting_pad_mask[
        context_position_perm_id < 1] = True  # exclude single-turn and unshuffled dialogs
    loss_ctx_uop = listNet(sorting_scores, context_position_ids, sorting_pad_mask)
    # loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

    loss = loss_ctx_mlm + loss_ctx_uop

    return {
        'loss': loss,
        'loss_ctx_mlm': loss_ctx_mlm,
        'loss_ctx_uop': loss_ctx_uop
    }

def train_decoder(self, context, context_utts_attn_mask, context_attn_mask,
                  context_lm_targets, context_position_perm_id,
                  context_position_ids, response):
    """ only train the decoder """
    self.decoder.train()

    with torch.no_grad():
        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

    ## train decoder
    dec_input, dec_target = response[:, :-1].contiguous(), response[:, 1:].clone()
    dec_output, *_ = self.decoder(
        dec_input,
        dec_input.ne(self.tokenizer.pad_token_id).long(),
        None, None, None, None,
        encoder_hidden_states=context_hiddens,
        encoder_attention_mask=context_attn_mask,
    )
    batch_size, seq_len, vocab_size = dec_output.size()
    dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
    dec_target[context_position_perm_id > 1] = -100  # ignore responses whose contexts are shuffled
    loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                      dec_target.view(-1))

    results = {'loss': loss_decoder, 'loss_decoder': loss_decoder}
    return results

def forward(self, context, context_utts_attn_mask, context_attn_mask,
            context_mlm_targets, context_position_perm_id, context_position_ids,
            response):
    self.train()
    # context: [batch_size x diag_len x max_utt_len]
    batch_size, max_ctx_len, max_utt_len = context.size()

    context_hiddens, context_encoding = self.context_encoding(
        context, context_utts_attn_mask, context_attn_mask)

    ## train dialog flow modeling
    context_mlm_targets[context_mlm_targets == -100] = 0
    ctx_mlm_mask = context_mlm_targets.sum(2)  # [batch_size x num_utts]
    if (ctx_mlm_mask > 0).sum() == 0:
        ctx_mlm_mask[0, 0] = 1
    ctx_mlm_mask = ctx_mlm_mask > 0

    with torch.no_grad():
        _, mlm_tgt_encodings, *_ = self.utt_encoder.bert(
            context_mlm_targets[ctx_mlm_mask],
            context_utts_attn_mask[ctx_mlm_mask])

    if self.mlm_mode == 'mdn':  # mixture density network
        mlm_pred_pi, mlm_pred_normal = self.context_mlm_trans(
            self.dropout(context_hiddens[ctx_mlm_mask]))
        loss_ctx_mlm = self.context_mlm_trans.loss(mlm_pred_pi, mlm_pred_normal,
                                                   mlm_tgt_encodings)
    else:  # simply mean square loss
        mlm_pred_encodings = self.context_mlm_trans(
            self.dropout(context_hiddens[ctx_mlm_mask]))
loss_ctx_mlm = MSELoss()( mlm_pred_encodings, mlm_tgt_encodings) # [num_selected_utts x dim] # context order prediction if isinstance(self.context_order_trans, SelfSorting): sorting_scores = self.context_order_trans(context_hiddens, context_attn_mask) else: sorting_scores = self.context_order_trans(context_hiddens) sorting_pad_mask = context_attn_mask == 0 sorting_pad_mask[ context_position_perm_id < 1] = True # exclude single-turn and unshuffled dialogs loss_ctx_uop = listNet(sorting_scores, context_position_ids, sorting_pad_mask) #loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask) ## train decoder dec_input, dec_target = response[:, :-1].contiguous( ), response[:, 1:].clone() dec_output, *_ = self.decoder( dec_input, dec_input.ne(self.tokenizer.pad_token_id).long(), None, None, None, None, encoder_hidden_states=context_hiddens, encoder_attention_mask=context_attn_mask, ) batch_size, seq_len, vocab_size = dec_output.size() dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100 dec_target[context_position_perm_id > 1] = -100 # ignore responses whose context was shuffled loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size), dec_target.view(-1)) loss = loss_ctx_mlm + loss_ctx_uop + loss_decoder results = { 'loss': loss, 'loss_ctx_mlm': loss_ctx_mlm, 'loss_ctx_uop': loss_ctx_uop, 'loss_decoder': loss_decoder } return results def validate(self, context, context_utts_attn_mask, context_attn_mask, context_lm_targets, context_position_perm_id, context_position_ids, response): results = self.train_decoder(context, context_utts_attn_mask, context_attn_mask, context_lm_targets, context_position_perm_id, context_position_ids, response) return results['loss'].item() def generate(self, input_batch, max_len=30, num_samples=1, mode='sample'): self.eval() device = next(self.parameters()).device context, context_utts_attn_mask, context_attn_mask = [ t.to(device) for t in input_batch[:3] ] ground_truth = input_batch[6].numpy() context_hiddens, context_encoding = self.context_encoding( context, context_utts_attn_mask, context_attn_mask) generated = torch.zeros( (num_samples, 1), dtype=torch.long, device=device).fill_(self.tokenizer.cls_token_id) # [batch_sz x 1] (1=seq_len) sample_lens = torch.ones((num_samples, 1), dtype=torch.long, device=device) len_inc = torch.ones((num_samples, 1), dtype=torch.long, device=device) for _ in range(max_len): outputs, *_ = self.decoder( generated, generated.ne(self.tokenizer.pad_token_id).long(), None, None, None, None, encoder_hidden_states=context_hiddens, encoder_attention_mask=context_attn_mask, ) # [batch_size x seq_len x vocab_size] next_token_logits = outputs[:, -1, :] / self.decoder_config.temperature # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858) for i in range(num_samples): for _ in set(generated[i].tolist()): next_token_logits[ i, _] /= self.decoder_config.repetition_penalty filtered_logits = top_k_top_p_filtering( next_token_logits, top_k=self.decoder_config.top_k, top_p=self.decoder_config.top_p) if mode == 'greedy': # greedy sampling: next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1) else: next_token = torch.multinomial(torch.softmax(filtered_logits, dim=-1), num_samples=num_samples) next_token[len_inc == 0] = self.tokenizer.pad_token_id generated = torch.cat((generated, next_token), dim=1) len_inc = len_inc * ( next_token != self.tokenizer.sep_token_id).long( ) # stop incresing length (set 0 bit) when EOS is encountered if len_inc.sum() < 1: break sample_lens = sample_lens + 
len_inc # to numpy sample_words = generated.data.cpu().numpy() sample_lens = sample_lens.data.cpu().numpy() context = context.data.cpu().numpy() return sample_words, sample_lens, context, ground_truth # nparray: [repeat x seq_len]
def __init__(self, config, src_len):
    super(LayoutlmSPLMPredictionHead, self).__init__()
    self.transform = BertPredictionHeadTransform(config)
    self.bias = nn.Parameter(torch.zeros(src_len))
def build(self): # build the base model (based on DETR) self.unit_base_model = UniTBaseModel(self.config.base_args) def keep_only_backbone_params(model_state_dict): keys = list(model_state_dict.keys()) for k in keys: if "backbone" not in k: model_state_dict.pop(k) ckpt_path = self.config.base_ckpt_path if ckpt_path != "": logger.info(f"initializing base model (UniT) from {ckpt_path}") if ckpt_path.startswith("https"): base_checkpoint = torch.hub.load_state_dict_from_url( ckpt_path, check_hash=True) else: base_checkpoint = torch.load(ckpt_path) if self.config.base_ckpt_load_backbone_only: keep_only_backbone_params(base_checkpoint["model"]) self.unit_base_model.load_state_dict(base_checkpoint["model"], strict=False) else: self.unit_base_model.load_state_dict(base_checkpoint["model"], strict=True) # build the text encoder (BERT) self.bert_model = TransformerEncoder(self.config.base_args.bert_config) detr_hidden_dim = self.config.base_args.decoder_hidden_dim bert_config = deepcopy(self.bert_model.config) self.bert_projection = nn.Linear(bert_config.hidden_size, detr_hidden_dim) self.bert_pos_projection = nn.Linear(bert_config.hidden_size, detr_hidden_dim) self.classifiers = nn.ModuleDict() self.task_embeddings_lang = nn.Identity() if self.config.base_args.use_task_embedding_in_lang_encoder: self.task_embeddings_lang = nn.Embedding(self.config.max_task_num, bert_config.hidden_size) bert_config.hidden_size = detr_hidden_dim # build the task-specific output heads self.class_embeds = nn.ModuleDict() self.bbox_embeds = nn.ModuleDict() self.det_losses = nn.ModuleDict() for dataset_name in self.config.base_args.num_queries.get( "detection", []): num_cls = self.config.heads["detection"][dataset_name][ "num_classes"] self.class_embeds[dataset_name] = nn.Linear( detr_hidden_dim, num_cls + 1) self.bbox_embeds[dataset_name] = MLP(detr_hidden_dim, detr_hidden_dim, 4, 3) attr_head = None if self.config.heads["detection"][dataset_name]["use_attr"]: attr_head = AttributeHead( num_cls, self.config.base_args.attribute_class_num, detr_hidden_dim) self.det_losses[dataset_name] = build_detection_loss( self.config.base_args, num_cls, attr_head) vl_classifiers = nn.ModuleDict() for dataset_name in self.config.base_args.num_queries.get("vl", []): vl_classifiers[dataset_name] = nn.Sequential( BertPredictionHeadTransform(bert_config), nn.Linear( bert_config.hidden_size, self.config.heads["vl"][dataset_name]["num_labels"], ), ) self.classifiers["vl"] = vl_classifiers self.dropout = nn.Dropout(bert_config.hidden_dropout_prob) glue_classifiers = nn.ModuleDict() for dataset_name in self.config.base_args.num_queries.get("glue", []): glue_classifiers[dataset_name] = nn.Sequential( BertPredictionHeadTransform(bert_config), nn.Linear( bert_config.hidden_size, self.config.heads["glue"][dataset_name]["num_labels"], ), ) self.classifiers["glue"] = glue_classifiers self.loss_calculation_fn = {} self.loss_calculation_fn["detection"] = self.detection_loss_calculation self.loss_calculation_fn["vl"] = self.classifier_loss_calculation self.loss_calculation_fn["glue"] = self.classifier_loss_calculation self.losses_dict = {} self.losses_dict["vl"] = { name: self.get_loss_fn(self.config.heads["vl"][name]["loss_type"]) for name in self.config.heads["vl"] } self.losses_dict["glue"] = { name: self.get_loss_fn(self.config.heads["glue"][name]["loss_type"]) for name in self.config.heads["glue"] }
def __init__(self, config, v_feature_size):
    super(BertImagePredictionHead, self).__init__()
    self.transform = BertPredictionHeadTransform(config)
    self.decoder = nn.Linear(config.hidden_size, v_feature_size)