def __init__(self, config):
    """Initialize the model from a config dict.

    Args:
        config: a Python dict (or namespace) that must contain the attributes below:
            config.model_type: model type string, e.g. 'bert' or 'xlnet'
            config.bert_model_path: pretrained model path or model name, e.g. 'bert-base-chinese'
            config.hidden_size: same as the BERT model's hidden size, usually 768
            config.num_classes: int, e.g. 2
            config.dropout: float between 0 and 1
            config.dim_capsule: dimension of each capsule
            config.num_compressed_capsule: number of compressed capsules
            config.max_seq_len: maximum input sequence length
    """
    super().__init__()
    if 'xl' in config.model_type:
        self.bert = AutoModel.from_pretrained(config.bert_model_path)
    else:
        self.bert = BertModel.from_pretrained(config.bert_model_path)
    # Fine-tune all encoder parameters.
    for param in self.bert.parameters():
        param.requires_grad = True
    self.dropout = nn.Dropout(config.dropout)
    self.linear = nn.Linear(4, config.num_classes)
    self.num_classes = config.num_classes
    self.dim_capsule = config.dim_capsule
    self.num_compressed_capsule = config.num_compressed_capsule
    # N-gram convolutions over the token dimension of the encoder output.
    self.ngram_size = [2, 4, 8]
    self.convs_doc = nn.ModuleList([
        nn.Conv1d(config.max_seq_len, 32, K, stride=2) for K in self.ngram_size
    ])
    for conv in self.convs_doc:
        torch.nn.init.xavier_uniform_(conv.weight)
    self.primary_capsules_doc = PrimaryCaps(num_capsules=self.dim_capsule,
                                            in_channels=32,
                                            out_channels=32,
                                            kernel_size=1,
                                            stride=1)
    self.flatten_capsules = FlattenCaps()
    # The number of flattened primary capsules depends on the encoder hidden size.
    if config.hidden_size == 768:
        self.W_doc = nn.Parameter(
            torch.FloatTensor(147328, self.num_compressed_capsule))
    else:  # hidden_size == 1024
        self.W_doc = nn.Parameter(
            torch.FloatTensor(196480, self.num_compressed_capsule))
    torch.nn.init.xavier_uniform_(self.W_doc)
    self.fc_capsules_doc_child = FCCaps(
        config,
        output_capsule_num=config.num_classes,
        input_capsule_num=self.num_compressed_capsule,
        in_channels=self.dim_capsule,
        out_channels=self.dim_capsule)
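# Example (a minimal sketch, not part of the original code): the attribute names
# below mirror the ones read in the __init__ above; the concrete values and the
# class name are illustrative assumptions only. Assumes the usual module-level
# imports (torch, torch.nn as nn, transformers.BertModel / AutoModel).
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(
#         model_type='bert',
#         bert_model_path='bert-base-chinese',
#         hidden_size=768,
#         num_classes=2,
#         dropout=0.1,
#         dim_capsule=16,
#         num_compressed_capsule=128,
#         max_seq_len=512,
#     )
#     model = CapsuleBertClassifier(config)  # hypothetical name of the class this __init__ belongs to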
def __init__(self, config):
    super(BertSupportNetX, self).__init__()
    self.encoder = BertModel.from_pretrained(config.bert_model_path)
    self.config = config  # i.e. the parsed args
    self.max_query_length = self.config.max_query_len
    self.input_dim = config.hidden_size
    self.dropout_size = config.dropout
    self.dropout = nn.Dropout(self.dropout_size)
    self.dim_capsule = config.dim_capsule
    self.num_compressed_capsule = config.num_compressed_capsule
    # N-gram convolutions over the token dimension of the encoder output.
    self.ngram_size = [2, 4, 8]
    self.convs_doc = nn.ModuleList([
        nn.Conv1d(config.max_seq_len, 32, K, stride=2) for K in self.ngram_size
    ])
    for conv in self.convs_doc:
        torch.nn.init.xavier_uniform_(conv.weight)
    self.primary_capsules_doc = PrimaryCaps(num_capsules=self.dim_capsule,
                                            in_channels=32,
                                            out_channels=32,
                                            kernel_size=1,
                                            stride=1)
    self.flatten_capsules = FlattenCaps()
    self.W_doc = nn.Parameter(
        torch.FloatTensor(49024, self.num_compressed_capsule))
    torch.nn.init.xavier_uniform_(self.W_doc)
    self.fc_capsules_doc_child = FCCaps(
        config,
        output_capsule_num=config.num_classes,
        input_capsule_num=self.num_compressed_capsule,
        in_channels=self.dim_capsule,
        out_channels=self.dim_capsule)
    # Span, answer-type, and supporting-fact heads.
    self.start_linear = nn.Linear(self.input_dim * 2, 1)
    self.end_linear = nn.Linear(self.input_dim * 2, 1)
    self.type_linear = nn.Linear(self.input_dim, config.num_classes)  # yes/no/ans/unknown
    self.sp_linear = nn.Linear(self.input_dim, 1)
    self.cache_S = 0
    self.cache_mask = None
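# Example (a minimal sketch, not part of the original code): BertSupportNetX reads
# the attributes listed below from `config`; the concrete values are illustrative
# assumptions only (num_classes=4 follows the yes/no/ans/unknown comment above).
#
#     from types import SimpleNamespace
#     config = SimpleNamespace(
#         bert_model_path='bert-base-chinese',
#         hidden_size=768,
#         num_classes=4,
#         dropout=0.1,
#         dim_capsule=16,
#         num_compressed_capsule=128,
#         max_seq_len=512,
#         max_query_len=50,
#     )
#     model = BertSupportNetX(config)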