Example #1
 def _init_classifier(self, hidden_size):
     if "pretraining" in self.config.training_head_type:
         self.classifier = BertPreTrainingHeads(self.bert_config)
     if "vqa" in self.config.training_head_type:
         self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
         self.answer_space_size = 3129
         self.classifier = nn.Sequential(
             BertPredictionHeadTransform(self.bert_config),
             nn.Linear(self.bert_config.hidden_size,
                       self.answer_space_size),
         )
         # self.classifier = nn.Linear(self.bert_config.hidden_size,
         # self.answer_space_size)
     elif "vizwiz" in self.config.training_head_type:
         self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
         self.answer_space_size = 7371
         self.classifier = nn.Sequential(
             BertPredictionHeadTransform(self.bert_config),
             nn.Linear(self.bert_config.hidden_size,
                       self.answer_space_size),
         )
         # self.classifier = nn.Linear(self.bert_config.hidden_size,
         # self.answer_space_size)
     elif self.config.training_head_type == "visual_entailment":
         self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)
         self.classifier = nn.Sequential(
             BertPredictionHeadTransform(self.bert_config),
             nn.Linear(self.bert_config.hidden_size, 3),
         )
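For reference, BertPredictionHeadTransform (from Hugging Face Transformers) applies a linear layer, the configured activation, and LayerNorm while keeping the hidden size unchanged; the snippets in this collection wrap it with a final nn.Linear to produce task logits. Below is a minimal, self-contained sketch of that pattern; the sizes, the label count, and the import path (which assumes a recent transformers release) are illustrative, not taken from any example above.

import torch
from torch import nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertPredictionHeadTransform

# Illustrative config and label count (not from any snippet above).
config = BertConfig(hidden_size=768)
num_labels = 2

classifier = nn.Sequential(
    nn.Dropout(config.hidden_dropout_prob),     # regularize the pooled representation
    BertPredictionHeadTransform(config),        # Linear -> activation -> LayerNorm, keeps hidden_size
    nn.Linear(config.hidden_size, num_labels),  # project to label logits
)

pooled_output = torch.randn(4, config.hidden_size)  # e.g. a BertPooler output for a batch of 4
logits = classifier(pooled_output)                  # shape: (4, num_labels)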
Example #2
    def __init__(self, config: BertConfig):
        super().__init__()

        self.transform1 = BertPredictionHeadTransform(config)
        self.transform2 = BertPredictionHeadTransform(config)

        self.decoder = nn.Linear(config.hidden_size, config.hidden_size)
Example #3
    def __init__(self, args, base_model_name='bert-base-uncased'):
        super(DialogBERT, self).__init__()

        if args.language == 'chinese': base_model_name = 'bert-base-chinese'

        self.tokenizer = BertTokenizer.from_pretrained(base_model_name,
                                                       cache_dir='./cache/')
        if args.model_size == 'tiny':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=256,
                                             num_hidden_layers=6,
                                             num_attention_heads=2,
                                             intermediate_size=1024)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        elif args.model_size == 'small':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=512,
                                             num_hidden_layers=8,
                                             num_attention_heads=4,
                                             intermediate_size=2048)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        else:
            self.encoder_config = BertConfig.from_pretrained(
                base_model_name, cache_dir='./cache/')
            self.utt_encoder = BertForPreTraining.from_pretrained(
                base_model_name,
                config=self.encoder_config,
                cache_dir='./cache/')

        self.context_encoder = BertModel(
            self.encoder_config)  # context encoder: encode context to vector

        self.mlm_mode = 'mse'  # 'mdn', 'mse'
        if self.mlm_mode == 'mdn':
            self.context_mlm_trans = MixtureDensityNetwork(
                self.encoder_config.hidden_size,
                self.encoder_config.hidden_size, 3)
        else:
            self.context_mlm_trans = BertPredictionHeadTransform(
                self.encoder_config
            )  # transform context hidden states back to utterance encodings

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        #       self.context_order_trans = MLP(self.encoder_config.hidden_size, '200-200-200', 1)

        self.decoder_config = deepcopy(self.encoder_config)
        self.decoder_config.is_decoder = True
        self.decoder_config.add_cross_attention = True
        self.decoder = BertLMHeadModel(self.decoder_config)
Example #4
    def __init__(self, config):
        super(BertImagePredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.target_size)
Example #5
    def build(self):
        self.text_processor = registry.get(self._datasets[0] +
                                           "_text_processor")
        self.vocab = self.text_processor.vocab
        self.word_embedding = self.vocab.get_embedding(
            torch.nn.Embedding,
            freeze=False,
            embedding_dim=self.config.text_embedding.embedding_dim)
        self.segment_embeddings = nn.Embedding(self.config.num_segment_type,
                                               self.config.hidden_size)

        self.cls_project = nn.Linear(self.config.text_embedding.embedding_dim,
                                     self.config.hidden_size)
        self.lstm = nn.LSTM(**self.config.lstm)
        self.lstm_proj = nn.Linear(self.config.hidden_size * 2,
                                   self.config.hidden_size)
        self.img_encoder = ImageClevrEncoder(self.config)
        self.img_pos_emb = nn.Linear(2, self.config.hidden_size)

        self.LayerNorm = nn.LayerNorm(self.config.hidden_size,
                                      eps=self.config.layer_norm_eps)
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)

        self.bert_config = BertConfig.from_dict(
            OmegaConf.to_container(self.config, resolve=True))
        self.transformer = BertEncoder(self.bert_config)
        self.pooler = BertPooler(self.bert_config)

        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.config),
            nn.Linear(self.config.hidden_size, self.config.num_labels),
        )

        self.head_mask = [None for _ in range(self.config.num_hidden_layers)]
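Several examples above build their BertConfig by converting an OmegaConf node to a plain dict, as in BertConfig.from_dict(OmegaConf.to_container(self.config, resolve=True)). A small sketch of that conversion with made-up keys is shown below; fields that are omitted fall back to the BertConfig defaults, and extra keys are kept as plain attributes on the resulting config.

from omegaconf import OmegaConf
from transformers import BertConfig

# Illustrative keys only; real configs in the examples above carry many more fields.
cfg = OmegaConf.create({
    "hidden_size": 768,
    "num_hidden_layers": 12,
    "num_attention_heads": 12,
    "hidden_dropout_prob": 0.1,
    "visual_embedding_dim": 2048,   # not a BertConfig field: stored as an extra attribute
})
bert_config = BertConfig.from_dict(OmegaConf.to_container(cfg, resolve=True))
print(bert_config.hidden_size, bert_config.visual_embedding_dim)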
Example #6
    def __init__(self,
                 config: BertConfig,
                 bert_model_embedding_weights,
                 position_embedding_size=200):
        super().__init__()
        self.position_embeddings = nn.Embedding(config.max_position_embeddings,
                                                position_embedding_size)

        self.pos_emb_proj = nn.Linear(position_embedding_size,
                                      config.hidden_size)

        self.transform = BertPredictionHeadTransform(config)
        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(bert_model_embedding_weights.size(1),
                                 bert_model_embedding_weights.size(0),
                                 bias=False)
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(
            torch.zeros(bert_model_embedding_weights.size(0)))

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer(
            "position_ids",
            torch.arange(config.max_position_embeddings).expand((1, -1)))
Example #7
    def __init__(self, config, extra_config):
        super().__init__()
        self.config = config
        self.output_attentions = self.config.output_attentions
        self.output_hidden_states = self.config.output_hidden_states
        self.pooler_strategy = self.config.get("pooler_strategy", "default")

        # Graph input params
        self.feed_graph_to_vb = extra_config["feed_graph_to_vb"]
        self.graph_node_hid_dim = extra_config["node_hid_dim"]
        self.graph_feed_mode = extra_config["feed_mode"]
        self.graph_topk = extra_config["topk_ans_feed"]

        # If doing graph, make a graph embedding layer
        if self.feed_graph_to_vb:
            self.graph_embedding = nn.Sequential(
                nn.Linear(self.graph_node_hid_dim, config.hidden_size),
                nn.LayerNorm(config.hidden_size, eps=1e-12),
                nn.Dropout(config.hidden_dropout_prob),  # hidden_dropout_prob
            )

        # If bert_model_name is not specified, you will need to specify
        # all of the required parameters for BERTConfig and a pretrained
        # model won't be loaded
        self.bert_model_name = self.config.get("bert_model_name", None)
        self.bert_config = BertConfig.from_dict(
            OmegaConf.to_container(self.config, resolve=True)
        )
        if self.bert_model_name is None or self.bert_model_name == "nopretrain":
            self.bert = VisualBERTBase(
                self.bert_config,
                visual_embedding_dim=self.config.visual_embedding_dim,
                embedding_strategy=self.config.embedding_strategy,
                bypass_transformer=self.config.bypass_transformer,
                output_attentions=self.config.output_attentions,
                output_hidden_states=self.config.output_hidden_states,
            )
        else:
            self.bert = VisualBERTBase.from_pretrained(
                self.config.bert_model_name,
                config=self.bert_config,
                cache_dir=os.path.join(
                    get_mmf_cache_dir(), "distributed_{}".format(-1)
                ),
                visual_embedding_dim=self.config.visual_embedding_dim,
                embedding_strategy=self.config.embedding_strategy,
                bypass_transformer=self.config.bypass_transformer,
                output_attentions=self.config.output_attentions,
                output_hidden_states=self.config.output_hidden_states,
            )

        self.training_head_type = self.config.training_head_type
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        if self.config.training_head_type == "nlvr2":
            self.bert.config.hidden_size *= 2
        self.classifier = nn.Sequential(BertPredictionHeadTransform(self.bert.config))

        self.init_weights()
Example #8
    def __init__(self, config, decoder_weight):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder_weight = decoder_weight

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))
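Example #8 only stores the shared decoder weight and a per-token bias; its forward pass is not shown here. A hedged sketch of how such a head would typically compute vocabulary logits follows, assuming decoder_weight has shape (vocab_size, hidden_size) and is tied to the input embeddings; the function name is hypothetical.

import torch
import torch.nn.functional as F

def lm_head_forward(head, hidden_states):
    # hidden_states: (batch, seq_len, hidden_size)
    hidden_states = head.transform(hidden_states)               # Linear -> activation -> LayerNorm
    # Multiply by the tied embedding matrix and add the output-only bias.
    logits = F.linear(hidden_states, head.decoder_weight) + head.bias
    return logits                                               # (batch, seq_len, vocab_size)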
Example #9
 def from_pretrained(self, model_dir):
     self.encoder_config = BertConfig.from_pretrained(model_dir)
     self.tokenizer = BertTokenizer.from_pretrained(
         path.join(model_dir, 'tokenizer'),
         do_lower_case=args.do_lower_case)
     self.utt_encoder = BertForPreTraining.from_pretrained(
         path.join(model_dir, 'utt_encoder'))
     self.context_encoder = BertForSequenceClassification.from_pretrained(
         path.join(model_dir, 'context_encoder'))
     self.context_mlm_trans = BertPredictionHeadTransform(
         self.encoder_config)
     self.context_mlm_trans.load_state_dict(
         torch.load(path.join(model_dir, 'context_mlm_trans.pkl')))
     self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
     self.context_order_trans.load_state_dict(
         torch.load(path.join(model_dir, 'context_order_trans.pkl')))
     self.decoder_config = BertConfig.from_pretrained(model_dir)
     self.decoder = BertLMHeadModel.from_pretrained(
         path.join(model_dir, 'decoder'))
Example #10
    def __init__(self, config, *args, **kwargs):
        super().__init__()
        self.config = config
        self.bert = MMBTBase(config, *args, **kwargs)
        self.encoder_config = self.bert.encoder_config

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.encoder_config),
            nn.Linear(self.encoder_config.hidden_size, self.config.num_labels),
        )
Example #11
 def __init__(self, config):
     super(BertForImageCaptioning, self).__init__(config)
     self.config = config
     self.bert = BertImgModel(config)
     self.transform = BertPredictionHeadTransform(config)
     bert_embedding_weight = self.bert.embeddings.word_embeddings.weight
     self.decoder = nn.Linear(bert_embedding_weight.size(1),
                              bert_embedding_weight.size(0),
                              bias=False)
     self.loss = nn.CrossEntropyLoss(reduction='mean')
     self.drop_worst_ratio = 0.2
Example #12
 def __init__(self, config=None, *args, **kwargs):
     super().__init__(*args, **kwargs)
     if config is None:
         from transformers.configuration_bert import BertConfig
         config = BertConfig.from_pretrained('bert-base-uncased')
         assert config.hidden_size == self.in_dim
     from transformers.modeling_bert import BertPredictionHeadTransform
     self.module = nn.Sequential(
         nn.Dropout(config.hidden_dropout_prob),
         BertPredictionHeadTransform(config),
         nn.Linear(self.in_dim, self.out_dim),
     )
Example #13
    def __init__(self, config):
        super(NodeConstructOutputLayer, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size, config.x_size, bias=False)

        self.bias = nn.Parameter(torch.zeros(config.x_size))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias
Example #14
    def __init__(self, config: Config, *args, **kwargs):
        super().__init__(config, *args, **kwargs)

        # Head modules
        self.pooler = BertPooler(self.config)
        self.classifier = nn.Sequential(
            nn.Dropout(self.config.hidden_dropout_prob),
            BertPredictionHeadTransform(self.config),
            nn.Linear(self.config.hidden_size, self.config.num_labels),
        )
        self.num_labels = self.config.num_labels
        self.hidden_size = self.config.hidden_size
Example #15
    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        self.visual_losses = config.visual_losses

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder_dict = nn.ModuleDict({
            key: nn.Linear(config.hidden_size,
                           config.visual_loss_config[key][0])
            for key in self.visual_losses
        })
Example #16
 def build_heads(self):
     """Initialize the classifier head. It takes the output of the
     transformer encoder and passes it through a pooler (we use the pooler from BERT
     model), then dropout, BertPredictionHeadTransform (which is a linear layer,
     followed by activation and layer norm) and lastly a linear layer projecting the
     hidden output to classification labels.
     """
     self.classifier = nn.Sequential(
         BertPooler(self.transformer_config),
         nn.Dropout(self.transformer_config.hidden_dropout_prob),
         BertPredictionHeadTransform(self.transformer_config),
         nn.Linear(self.transformer_config.hidden_size, self.config.num_labels),
     )
Example #17
    def __init__(self, config, bert_model_embedding_weights):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(
            bert_model_embedding_weights.size(1),
            bert_model_embedding_weights.size(0),
            bias=False,
        )
        self.decoder.weight = bert_model_embedding_weights
        self.bias = nn.Parameter(
            torch.zeros(bert_model_embedding_weights.size(0)))
Example #18
    def __init__(self, config, *args, **kwargs):
        super().__init__()
        self.config = config
        self.bert = MMBTBase(config, *args, **kwargs)
        self.encoder_config = self.bert.encoder_config
        self.num_labels = self.config.num_labels
        self.output_hidden_states = self.encoder_config.output_hidden_states
        self.output_attentions = self.encoder_config.output_attentions
        self.fused_feature_only = self.config.fused_feature_only

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.encoder_config),
            nn.Linear(self.encoder_config.hidden_size, self.config.num_labels),
        )
Example #19
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.output_attentions = self.config.output_attentions
        self.output_hidden_states = self.config.output_hidden_states
        self.pooler_strategy = self.config.get("pooler_strategy", "default")

        # If bert_model_name is not specified, you will need to specify
        # all of the required parameters for BERTConfig and a pretrained
        # model won't be loaded
        self.bert_model_name = getattr(self.config, "bert_model_name", None)
        self.bert_config = BertConfig.from_dict(
            OmegaConf.to_container(self.config, resolve=True)
        )
        if self.bert_model_name is None:
            self.bert = VisualBERTBase(
                self.bert_config,
                visual_embedding_dim=self.config.visual_embedding_dim,
                embedding_strategy=self.config.embedding_strategy,
                bypass_transformer=self.config.bypass_transformer,
                output_attentions=self.config.output_attentions,
                output_hidden_states=self.config.output_hidden_states,
            )
        else:
            self.bert = VisualBERTBase.from_pretrained(
                self.config.bert_model_name,
                config=self.bert_config,
                cache_dir=os.path.join(
                    get_multimodelity_cache_dir(), "distributed_{}".format(-1)
                ),
                visual_embedding_dim=self.config.visual_embedding_dim,
                embedding_strategy=self.config.embedding_strategy,
                bypass_transformer=self.config.bypass_transformer,
                output_attentions=self.config.output_attentions,
                output_hidden_states=self.config.output_hidden_states,
            )

        self.training_head_type = self.config.training_head_type
        self.num_labels = self.config.num_labels
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        if self.config.training_head_type == "nlvr2":
            self.bert.config.hidden_size *= 2
        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.bert.config),
            nn.Linear(self.bert.config.hidden_size, self.config.num_labels),
        )

        self.init_weights()
Example #20
    def __init__(self, **kwargs):
        super().__init__()
        self.config = kwargs
        self.output_attentions = self.config['output_attentions']
        self.output_hidden_states = self.config['output_hidden_states']
        self.pooler_strategy = self.config.get('pooler_strategy', 'default')

        # If bert_model_name is not specified, you will need to specify
        # all of the required parameters for BERTConfig and a pretrained
        # model won't be loaded
        self.bert_model_name = self.config['bert_model_name']
        self.bert_config = BertConfig.from_dict(self.config)
        if self.bert_model_name is None:
            self.bert = VisualBERTBase(
                self.bert_config,
                visual_embedding_dim=self.config['visual_embedding_dim'],
                embedding_strategy=self.config['embedding_strategy'],
                bypass_transformer=self.config['bypass_transformer'],
                output_attentions=self.config['output_attentions'],
                output_hidden_states=self.config['output_hidden_states'],
            )
        else:
            from imix.utils.config import ToExpanduser
            cache_dir = os.path.join('~/.cache/torch', 'transformers')
            cache_dir = ToExpanduser.modify_path(cache_dir)

            self.bert = VisualBERTBase.from_pretrained(
                self.config['bert_model_name'],
                config=self.bert_config,
                cache_dir=cache_dir,
                visual_embedding_dim=self.config['visual_embedding_dim'],
                embedding_strategy=self.config['embedding_strategy'],
                bypass_transformer=self.config['bypass_transformer'],
                output_attentions=self.config['output_attentions'],
                output_hidden_states=self.config['output_hidden_states'],
            )

        self.training_head_type = self.config['training_head_type']
        self.num_labels = self.config['num_labels']
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        if self.config['training_head_type'] == 'nlvr2':
            self.bert.config.hidden_size *= 2
        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.bert.config),
            nn.Linear(self.bert.config.hidden_size, self.config['num_labels']),
        )

        self.init_weights()
Example #21
 def build_heads(self):
     """Initialize the classifier head. It takes the output of the
     transformer encoder and passes it through a pooler (we use the pooler from BERT
     model), then dropout, BertPredictionHeadTransform (which is a linear layer,
     followed by activation and layer norm) and lastly a linear layer projecting the
     hidden output to classification labels.
     """
     transformer_config = self.backend.get_config()
     if self.config.training_head_type == "classification":
         self.pooler = BertPooler(transformer_config)
         self.classifier = nn.Sequential(
             nn.Dropout(transformer_config.hidden_dropout_prob),
             BertPredictionHeadTransform(transformer_config),
             nn.Linear(transformer_config.hidden_size, self.config.num_labels),
         )
     elif self.config.training_head_type == "pretraining":
         self.cls = BertOnlyMLMHead(transformer_config)
         self.vocab_size = transformer_config.vocab_size
Example #22
    def __init__(self, config):
        super().__init__()
        self.config = config
        # self.output_attentions = self.config.output_attentions
        self.output_hidden_states = self.config.output_hidden_states

        self.bert = VisualBERTBase(
            self.config,
            visual_embedding_dim=self.config.visual_embedding_dim,
        )

        self.num_labels = self.config.num_labels
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        # if self.config.training_head_type == "nlvr2":
        #     self.bert.config.hidden_size *= 2

        self.classifier = nn.Sequential(
            BertPredictionHeadTransform(self.bert.config),
            nn.Linear(self.bert.config.hidden_size, self.config.num_labels),
        )

        self.init_weights()
Example #23
    def __init__(self, config, bert_model_embedding_weights, hidden_size=768):
        super(BertLMPredictionHead, self).__init__()
        config.hidden_size = hidden_size
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        if hidden_size == 768:
            self.decoder = nn.Linear(
                bert_model_embedding_weights.size(1),
                bert_model_embedding_weights.size(0),
                bias=False,
            )
            # print(self.decoder.weight.shape) # 30522, 768
            self.decoder.weight = bert_model_embedding_weights
        else:
            self.decoder = nn.Linear(
                hidden_size,
                bert_model_embedding_weights.size(0),
                bias=False,
            )
        self.bias = nn.Parameter(
            torch.zeros(bert_model_embedding_weights.size(0)))
Example #24
 def __init__(self, config, num_labels):
     super(PredictionHead, self).__init__()
     self.transform = BertPredictionHeadTransform(config)
     self.decoder = nn.Linear(config.hidden_size, num_labels)
     self.bias = nn.Parameter(torch.zeros(num_labels))
Example #25
 def __init__(self, config):
     super().__init__()
     self.transform = BertPredictionHeadTransform(config)
     self.decoder = nn.Linear(config.hidden_size,
                              config.vocab_size,
                              bias=False)
Example #26
class DialogBERT(nn.Module):
    '''Hierarchical BERT for dialog (v5) with two features:
    - Masked context utterance prediction with direct MSE matching of the utterance vectors
    - Energy-based utterance order prediction: a novel approach that shuffles the context
      and predicts the original order with distributed order prediction'''

    # TODO: 1. Enhance sorting net
    #       2. Better data loader for permutation (avoid returning perm_id and use max(pos_ids) instead)

    def __init__(self, args, base_model_name='bert-base-uncased'):
        super(DialogBERT, self).__init__()

        if args.language == 'chinese': base_model_name = 'bert-base-chinese'

        self.tokenizer = BertTokenizer.from_pretrained(base_model_name,
                                                       cache_dir='./cache/')
        if args.model_size == 'tiny':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=256,
                                             num_hidden_layers=6,
                                             num_attention_heads=2,
                                             intermediate_size=1024)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        elif args.model_size == 'small':
            self.encoder_config = BertConfig(vocab_size=30522,
                                             hidden_size=512,
                                             num_hidden_layers=8,
                                             num_attention_heads=4,
                                             intermediate_size=2048)
            self.utt_encoder = BertForPreTraining(self.encoder_config)
        else:
            self.encoder_config = BertConfig.from_pretrained(
                base_model_name, cache_dir='./cache/')
            self.utt_encoder = BertForPreTraining.from_pretrained(
                base_model_name,
                config=self.encoder_config,
                cache_dir='./cache/')

        self.context_encoder = BertModel(
            self.encoder_config)  # context encoder: encode context to vector

        self.mlm_mode = 'mse'  # 'mdn', 'mse'
        if self.mlm_mode == 'mdn':
            self.context_mlm_trans = MixtureDensityNetwork(
                self.encoder_config.hidden_size,
                self.encoder_config.hidden_size, 3)
        else:
            self.context_mlm_trans = BertPredictionHeadTransform(
                self.encoder_config
            )  # transform context hidden states back to utterance encodings

        self.dropout = nn.Dropout(self.encoder_config.hidden_dropout_prob)
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        #       self.context_order_trans = MLP(self.encoder_config.hidden_size, '200-200-200', 1)

        self.decoder_config = deepcopy(self.encoder_config)
        self.decoder_config.is_decoder = True
        self.decoder_config.add_cross_attention = True
        self.decoder = BertLMHeadModel(self.decoder_config)

    def init_weights(self, m):  # Initialize Linear Weight for GAN
        if isinstance(m, nn.Linear):
            m.weight.data.uniform_(-0.08,
                                   0.08)  #nn.init.xavier_normal_(m.weight)
            nn.init.constant_(m.bias, 0.)

    @classmethod
    def from_pretrained(self, model_dir):
        self.encoder_config = BertConfig.from_pretrained(model_dir)
        self.tokenizer = BertTokenizer.from_pretrained(
            path.join(model_dir, 'tokenizer'),
            do_lower_case=args.do_lower_case)
        self.utt_encoder = BertForPreTraining.from_pretrained(
            path.join(model_dir, 'utt_encoder'))
        self.context_encoder = BertForSequenceClassification.from_pretrained(
            path.join(model_dir, 'context_encoder'))
        self.context_mlm_trans = BertPredictionHeadTransform(
            self.encoder_config)
        self.context_mlm_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_mlm_trans.pkl')))
        self.context_order_trans = SelfSorting(self.encoder_config.hidden_size)
        self.context_order_trans.load_state_dict(
            torch.load(path.join(model_dir, 'context_order_trans.pkl')))
        self.decoder_config = BertConfig.from_pretrained(model_dir)
        self.decoder = BertLMHeadModel.from_pretrained(
            path.join(model_dir, 'decoder'))

    def save_pretrained(self, output_dir):
        def save_module(model, save_path):
            torch.save(model.state_dict(), save_path)

        def make_list_dirs(dir_list):
            for dir_ in dir_list:
                os.makedirs(dir_, exist_ok=True)

        make_list_dirs([
            path.join(output_dir, name) for name in
            ['tokenizer', 'utt_encoder', 'context_encoder', 'decoder']
        ])
        model_to_save = self.module if hasattr(self, 'module') else self
        model_to_save.encoder_config.save_pretrained(
            output_dir)  # Save configuration file
        model_to_save.tokenizer.save_pretrained(
            path.join(output_dir, 'tokenizer'))
        model_to_save.utt_encoder.save_pretrained(
            path.join(output_dir, 'utt_encoder'))
        model_to_save.context_encoder.save_pretrained(
            path.join(output_dir, 'context_encoder'))
        save_module(model_to_save.context_mlm_trans,
                    path.join(output_dir, 'context_mlm_trans.pkl'))
        save_module(model_to_save.context_order_trans,
                    path.join(output_dir, 'context_order_trans.pkl'))
        model_to_save.decoder_config.save_pretrained(
            output_dir)  # Save configuration file
        model_to_save.decoder.save_pretrained(path.join(output_dir, 'decoder'))

    def utt_encoding(self, context, utts_attn_mask):
        batch_size, max_ctx_len, max_utt_len = context.size(
        )  #context: [batch_size x diag_len x max_utt_len]

        utts = context.view(
            -1, max_utt_len)  # [(batch_size*diag_len) x max_utt_len]
        utts_attn_mask = utts_attn_mask.view(-1, max_utt_len)
        _, utts_encodings, *_ = self.utt_encoder.bert(utts, utts_attn_mask)
        utts_encodings = utts_encodings.view(batch_size, max_ctx_len, -1)
        return utts_encodings

    def context_encoding(self, context, utts_attn_mask, ctx_attn_mask):
        #with torch.no_grad():
        utt_encodings = self.utt_encoding(context, utts_attn_mask)
        context_hiddens, pooled_output, *_ = self.context_encoder(
            None, ctx_attn_mask, None, None, None, utt_encodings)
        # context_hiddens:[batch_size x ctx_len x dim]; pooled_output=[batch_size x dim]

        return context_hiddens, pooled_output

    def train_dialog_flow(self, context, context_utts_attn_mask,
                          context_attn_mask, context_lm_targets,
                          context_position_perm_id, context_position_ids,
                          response):
        """
        only train the dialog flow model
        """
        self.context_encoder.train()  # set the module in training mode.
        self.context_mlm_trans.train()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)
        lm_pred_encodings = self.context_mlm_trans(
            self.dropout(context_hiddens))

        context_lm_targets[context_lm_targets == -100] = 0
        ctx_lm_mask = context_lm_targets.sum(2)
        if (ctx_lm_mask > 0).sum() == 0: ctx_lm_mask[0, 0] = 1
        lm_pred_encodings = lm_pred_encodings[ctx_lm_mask > 0]
        context_lm_targets = context_lm_targets[ctx_lm_mask > 0]
        context_lm_targets_attn_mask = context_utts_attn_mask[ctx_lm_mask > 0]

        with torch.no_grad():
            _, lm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_lm_targets, context_lm_targets_attn_mask)

        loss_ctx_mlm = MSELoss()(lm_pred_encodings,
                                 lm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens,
                                                      context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[
            context_position_perm_id <
            1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids,
                               sorting_pad_mask)
        #loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        loss = loss_ctx_mlm + loss_ctx_uop

        return {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop
        }

    def train_decoder(self, context, context_utts_attn_mask, context_attn_mask,
                      context_lm_targets, context_position_perm_id,
                      context_position_ids, response):
        """
         only train the decoder
         """
        self.decoder.train()

        with torch.no_grad():
            context_hiddens, context_encoding = self.context_encoding(
                context, context_utts_attn_mask, context_attn_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(
        ), response[:, 1:].clone()

        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None,
            None,
            None,
            None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )

        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id >
                   1] = -100  # ignore responses whose contexts are shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                          dec_target.view(-1))

        results = {'loss': loss_decoder, 'loss_decoder': loss_decoder}

        return results

    def forward(self, context, context_utts_attn_mask, context_attn_mask,
                context_mlm_targets, context_position_perm_id,
                context_position_ids, response):
        self.train()
        batch_size, max_ctx_len, max_utt_len = context.size(
        )  #context: [batch_size x diag_len x max_utt_len]

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        ## train dialog flow modeling
        context_mlm_targets[context_mlm_targets == -100] = 0
        ctx_mlm_mask = context_mlm_targets.sum(2)  #[batch_size x num_utts]
        if (ctx_mlm_mask > 0).sum() == 0: ctx_mlm_mask[0, 0] = 1
        ctx_mlm_mask = ctx_mlm_mask > 0

        with torch.no_grad():
            _, mlm_tgt_encodings, *_ = self.utt_encoder.bert(
                context_mlm_targets[ctx_mlm_mask],
                context_utts_attn_mask[ctx_mlm_mask])

        if self.mlm_mode == 'mdn':  # mixture density network
            mlm_pred_pi, mlm_pred_normal = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = self.context_mlm_trans.loss(mlm_pred_pi,
                                                       mlm_pred_normal,
                                                       mlm_tgt_encodings)
        else:  # simply mean square loss
            mlm_pred_encodings = self.context_mlm_trans(
                self.dropout(context_hiddens[ctx_mlm_mask]))
            loss_ctx_mlm = MSELoss()(
                mlm_pred_encodings,
                mlm_tgt_encodings)  # [num_selected_utts x dim]

        # context order prediction
        if isinstance(self.context_order_trans, SelfSorting):
            sorting_scores = self.context_order_trans(context_hiddens,
                                                      context_attn_mask)
        else:
            sorting_scores = self.context_order_trans(context_hiddens)
        sorting_pad_mask = context_attn_mask == 0
        sorting_pad_mask[
            context_position_perm_id <
            1] = True  # exclude single-turn and unshuffled dialogs
        loss_ctx_uop = listNet(sorting_scores, context_position_ids,
                               sorting_pad_mask)
        #loss_ctx_uop = listMLE(sorting_scores, context_position_ids, sorting_pad_mask)

        ## train decoder
        dec_input, dec_target = response[:, :-1].contiguous(
        ), response[:, 1:].clone()

        dec_output, *_ = self.decoder(
            dec_input,
            dec_input.ne(self.tokenizer.pad_token_id).long(),
            None,
            None,
            None,
            None,
            encoder_hidden_states=context_hiddens,
            encoder_attention_mask=context_attn_mask,
        )

        batch_size, seq_len, vocab_size = dec_output.size()
        dec_target[response[:, 1:] == self.tokenizer.pad_token_id] = -100
        dec_target[context_position_perm_id >
                   1] = -100  # ignore responses whose context was shuffled
        loss_decoder = CrossEntropyLoss()(dec_output.view(-1, vocab_size),
                                          dec_target.view(-1))

        loss = loss_ctx_mlm + loss_ctx_uop + loss_decoder

        results = {
            'loss': loss,
            'loss_ctx_mlm': loss_ctx_mlm,
            'loss_ctx_uop': loss_ctx_uop,
            'loss_decoder': loss_decoder
        }

        return results

    def validate(self, context, context_utts_attn_mask, context_attn_mask,
                 context_lm_targets, context_position_perm_id,
                 context_position_ids, response):
        results = self.train_decoder(context, context_utts_attn_mask,
                                     context_attn_mask, context_lm_targets,
                                     context_position_perm_id,
                                     context_position_ids, response)
        return results['loss'].item()

    def generate(self, input_batch, max_len=30, num_samples=1, mode='sample'):
        self.eval()
        device = next(self.parameters()).device
        context, context_utts_attn_mask, context_attn_mask = [
            t.to(device) for t in input_batch[:3]
        ]
        ground_truth = input_batch[6].numpy()

        context_hiddens, context_encoding = self.context_encoding(
            context, context_utts_attn_mask, context_attn_mask)

        generated = torch.zeros(
            (num_samples, 1), dtype=torch.long,
            device=device).fill_(self.tokenizer.cls_token_id)
        # [batch_sz x 1] (1=seq_len)

        sample_lens = torch.ones((num_samples, 1),
                                 dtype=torch.long,
                                 device=device)
        len_inc = torch.ones((num_samples, 1), dtype=torch.long, device=device)
        for _ in range(max_len):
            outputs, *_ = self.decoder(
                generated,
                generated.ne(self.tokenizer.pad_token_id).long(),
                None,
                None,
                None,
                None,
                encoder_hidden_states=context_hiddens,
                encoder_attention_mask=context_attn_mask,
            )  # [batch_size x seq_len x vocab_size]
            next_token_logits = outputs[:,
                                        -1, :] / self.decoder_config.temperature

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for _ in set(generated[i].tolist()):
                    next_token_logits[
                        i, _] /= self.decoder_config.repetition_penalty

            filtered_logits = top_k_top_p_filtering(
                next_token_logits,
                top_k=self.decoder_config.top_k,
                top_p=self.decoder_config.top_p)
            if mode == 'greedy':  # greedy sampling:
                next_token = torch.argmax(filtered_logits,
                                          dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(torch.softmax(filtered_logits,
                                                             dim=-1),
                                               num_samples=num_samples)
            next_token[len_inc == 0] = self.tokenizer.pad_token_id
            generated = torch.cat((generated, next_token), dim=1)
            len_inc = len_inc * (
                next_token != self.tokenizer.sep_token_id).long(
                )  # stop increasing length (set the bit to 0) when EOS is encountered
            if len_inc.sum() < 1: break
            sample_lens = sample_lens + len_inc

        # to numpy
        sample_words = generated.data.cpu().numpy()
        sample_lens = sample_lens.data.cpu().numpy()

        context = context.data.cpu().numpy()
        return sample_words, sample_lens, context, ground_truth  # nparray: [repeat x seq_len]
Example #27
    def __init__(self, config, src_len):
        super(LayoutlmSPLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        self.bias = nn.Parameter(torch.zeros(src_len))
Example #28
    def build(self):
        # build the base model (based on DETR)
        self.unit_base_model = UniTBaseModel(self.config.base_args)

        def keep_only_backbone_params(model_state_dict):
            keys = list(model_state_dict.keys())
            for k in keys:
                if "backbone" not in k:
                    model_state_dict.pop(k)

        ckpt_path = self.config.base_ckpt_path
        if ckpt_path != "":
            logger.info(f"initializing base model (UniT) from {ckpt_path}")
            if ckpt_path.startswith("https"):
                base_checkpoint = torch.hub.load_state_dict_from_url(
                    ckpt_path, check_hash=True)
            else:
                base_checkpoint = torch.load(ckpt_path)
            if self.config.base_ckpt_load_backbone_only:
                keep_only_backbone_params(base_checkpoint["model"])
                self.unit_base_model.load_state_dict(base_checkpoint["model"],
                                                     strict=False)
            else:
                self.unit_base_model.load_state_dict(base_checkpoint["model"],
                                                     strict=True)

        # build the text encoder (BERT)
        self.bert_model = TransformerEncoder(self.config.base_args.bert_config)
        detr_hidden_dim = self.config.base_args.decoder_hidden_dim
        bert_config = deepcopy(self.bert_model.config)
        self.bert_projection = nn.Linear(bert_config.hidden_size,
                                         detr_hidden_dim)
        self.bert_pos_projection = nn.Linear(bert_config.hidden_size,
                                             detr_hidden_dim)

        self.classifiers = nn.ModuleDict()

        self.task_embeddings_lang = nn.Identity()
        if self.config.base_args.use_task_embedding_in_lang_encoder:
            self.task_embeddings_lang = nn.Embedding(self.config.max_task_num,
                                                     bert_config.hidden_size)

        bert_config.hidden_size = detr_hidden_dim

        # build the task-specific output heads
        self.class_embeds = nn.ModuleDict()
        self.bbox_embeds = nn.ModuleDict()
        self.det_losses = nn.ModuleDict()
        for dataset_name in self.config.base_args.num_queries.get(
                "detection", []):
            num_cls = self.config.heads["detection"][dataset_name][
                "num_classes"]
            self.class_embeds[dataset_name] = nn.Linear(
                detr_hidden_dim, num_cls + 1)
            self.bbox_embeds[dataset_name] = MLP(detr_hidden_dim,
                                                 detr_hidden_dim, 4, 3)
            attr_head = None
            if self.config.heads["detection"][dataset_name]["use_attr"]:
                attr_head = AttributeHead(
                    num_cls, self.config.base_args.attribute_class_num,
                    detr_hidden_dim)
            self.det_losses[dataset_name] = build_detection_loss(
                self.config.base_args, num_cls, attr_head)

        vl_classifiers = nn.ModuleDict()
        for dataset_name in self.config.base_args.num_queries.get("vl", []):
            vl_classifiers[dataset_name] = nn.Sequential(
                BertPredictionHeadTransform(bert_config),
                nn.Linear(
                    bert_config.hidden_size,
                    self.config.heads["vl"][dataset_name]["num_labels"],
                ),
            )

        self.classifiers["vl"] = vl_classifiers
        self.dropout = nn.Dropout(bert_config.hidden_dropout_prob)

        glue_classifiers = nn.ModuleDict()
        for dataset_name in self.config.base_args.num_queries.get("glue", []):
            glue_classifiers[dataset_name] = nn.Sequential(
                BertPredictionHeadTransform(bert_config),
                nn.Linear(
                    bert_config.hidden_size,
                    self.config.heads["glue"][dataset_name]["num_labels"],
                ),
            )

        self.classifiers["glue"] = glue_classifiers

        self.loss_calculation_fn = {}
        self.loss_calculation_fn["detection"] = self.detection_loss_calculation
        self.loss_calculation_fn["vl"] = self.classifier_loss_calculation
        self.loss_calculation_fn["glue"] = self.classifier_loss_calculation

        self.losses_dict = {}
        self.losses_dict["vl"] = {
            name: self.get_loss_fn(self.config.heads["vl"][name]["loss_type"])
            for name in self.config.heads["vl"]
        }
        self.losses_dict["glue"] = {
            name:
            self.get_loss_fn(self.config.heads["glue"][name]["loss_type"])
            for name in self.config.heads["glue"]
        }
Example #29
 def __init__(self, config, v_feature_size):
     super(BertImagePredictionHead, self).__init__()
     self.transform = BertPredictionHeadTransform(config)
     self.decoder = nn.Linear(config.hidden_size, v_feature_size)
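The image prediction heads in Examples #4 and #29 follow the same pattern: transform the hidden states, then project them to the visual target size. A hedged sketch of the forward such a head typically implements is shown below; the function name and shapes are assumptions, not taken from the snippets.

def image_head_forward(head, sequence_output):
    # sequence_output: (batch, num_regions, hidden_size) hidden states for visual tokens
    hidden = head.transform(sequence_output)   # Linear -> activation -> LayerNorm
    return head.decoder(hidden)                # (batch, num_regions, v_feature_size)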