Example #1
    def __init__(self, args, answer_num):
        super(Labeler, self).__init__(args, answer_num)
        self.type_embed_dim = 1024
        self.def_embed_dim = 1024
        if args.data_setup == 'joint' and args.multitask:
            print("Multi-task learning")
            self.decoder = MultiSimpleDecoder(self.output_dim +
                                              self.type_embed_dim +
                                              self.def_embed_dim)
        else:
            self.decoder = SimpleDecoder(
                self.output_dim + self.type_embed_dim + self.def_embed_dim,
                answer_num)
        # Type vocabulary: the open-type answer set plus BOS / EOS / PAD symbols.
        self.type_vocab_size = constant.ANSWER_NUM_DICT['open'] + 3
        self.bos_idx = constant.TYPE_BOS_IDX
        self.eos_idx = constant.TYPE_EOS_IDX
        self.pad_idx = constant.TYPE_PAD_IDX
        self.def_vocab_size = constant.DEF_VOCAB_SIZE
        self.def_pad_idx = constant.DEF_PAD_IDX
        self.type_embedding = nn.Embedding(self.type_vocab_size,
                                           self.type_embed_dim,
                                           padding_idx=self.pad_idx)
        self.def_embedding = nn.Embedding(self.def_vocab_size,
                                          self.def_embed_dim,
                                          padding_idx=self.def_pad_idx)
        # Definition encoder: bidirectional, with def_embed_dim // 2 units per
        # direction so the concatenated output stays def_embed_dim wide.
        self.lstm_def = nn.LSTM(self.def_embed_dim,
                                self.def_embed_dim // 2,
                                bidirectional=True,
                                batch_first=True)
        self.def_attentive_sum = TypeAttentiveSum(self.output_dim,
                                                  self.type_hid_dim)
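The definition encoder above halves the hidden size and runs the LSTM bidirectionally, so the concatenated output keeps the 1024-dimensional embedding width. A quick standalone check of that shape bookkeeping (the names and sizes here are illustrative, not taken from the repository):

import torch
import torch.nn as nn

def_embed_dim = 1024
lstm_def = nn.LSTM(def_embed_dim, def_embed_dim // 2,
                   bidirectional=True, batch_first=True)
x = torch.randn(2, 7, def_embed_dim)   # [batch, def_len, def_embed_dim]
out, _ = lstm_def(x)
print(out.shape)                       # torch.Size([2, 7, 1024]): 512 per direction, concatenated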
Example #2
    def __init__(self, args, answer_num):
        super(ETModel, self).__init__(args, answer_num)
        self.output_dim = args.rnn_dim * 2
        self.mention_dropout = nn.Dropout(args.mention_dropout)
        self.input_dropout = nn.Dropout(args.input_dropout)
        self.dim_hidden = args.dim_hidden
        self.embed_dim = 1024
        self.mention_dim = 1024
        self.headword_dim = 1024
        self.enhanced_mention = args.enhanced_mention

        self.add_headword_emb = args.add_headword_emb
        self.mention_lstm = args.mention_lstm

        if args.enhanced_mention:
            self.head_attentive_sum = SelfAttentiveSum(self.mention_dim, 1)
            self.cnn = CNN()
            self.mention_dim += 50
        self.output_dim += self.mention_dim

        if self.add_headword_emb:
            self.output_dim += self.headword_dim

        # Defining LSTM here.
        self.attentive_sum = SelfAttentiveSum(args.rnn_dim * 2, 100)
        self.lstm = nn.LSTM(self.embed_dim + 50,
                            args.rnn_dim,
                            bidirectional=True,
                            batch_first=True)
        self.token_mask = nn.Linear(4, 50)

        if self.mention_lstm:
            self.lstm_mention = nn.LSTM(self.embed_dim,
                                        self.embed_dim // 2,
                                        bidirectional=True,
                                        batch_first=True)
            self.mention_attentive_sum = SelfAttentiveSum(self.embed_dim, 1)

        self.sigmoid_fn = nn.Sigmoid()
        self.goal = args.goal

        if args.data_setup == 'joint' and args.multitask:
            print("Multi-task learning")
            self.decoder = MultiSimpleDecoder(self.output_dim)
        else:
            self.decoder = SimpleDecoder(self.output_dim, answer_num)

        self.weighted_sum = ELMoWeightedSum()
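For orientation, a hypothetical instantiation that only shows which fields the constructor reads from args; the values are placeholders rather than the original training configuration, and it assumes the base class and the constant module are importable:

from argparse import Namespace

# Placeholder values; every field below is read somewhere in __init__ above.
args = Namespace(rnn_dim=100, mention_dropout=0.5, input_dropout=0.5,
                 dim_hidden=100, enhanced_mention=True, add_headword_emb=True,
                 mention_lstm=True, goal='open', data_setup='joint', multitask=True)
model = ETModel(args, answer_num=constant.ANSWER_NUM_DICT['open'])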
Example #3
    def __init__(self, args, answer_num):
        super(Model, self).__init__()
        self.output_dim = args.rnn_dim * 2
        self.mention_dropout = nn.Dropout(args.mention_dropout)
        self.input_dropout = nn.Dropout(args.input_dropout)
        self.dim_hidden = args.dim_hidden
        self.embed_dim = 300
        self.mention_dim = 300
        self.lstm_type = args.lstm_type
        self.enhanced_mention = args.enhanced_mention
        if args.enhanced_mention:
            self.head_attentive_sum = SelfAttentiveSum(self.mention_dim, 1)
            self.cnn = CNN()
            self.mention_dim += 50
        self.output_dim += self.mention_dim

        # Defining LSTM here.
        self.attentive_sum = SelfAttentiveSum(args.rnn_dim * 2, 100)
        if self.lstm_type == "two":
            self.left_lstm = nn.LSTM(self.embed_dim,
                                     100,
                                     bidirectional=True,
                                     batch_first=True)
            self.right_lstm = nn.LSTM(self.embed_dim,
                                      100,
                                      bidirectional=True,
                                      batch_first=True)
        elif self.lstm_type == 'single':
            self.lstm = nn.LSTM(self.embed_dim + 50,
                                args.rnn_dim,
                                bidirectional=True,
                                batch_first=True)
            self.token_mask = nn.Linear(4, 50)
        self.loss_func = nn.BCEWithLogitsLoss()
        self.sigmoid_fn = nn.Sigmoid()
        self.goal = args.goal
        self.multitask = args.multitask

        if args.data_setup == 'joint' and args.multitask:
            print("Multi-task learning")
            self.decoder = MultiSimpleDecoder(self.output_dim)
        else:
            self.decoder = SimpleDecoder(self.output_dim, answer_num)
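The decoder width follows directly from the snippet. A quick tally under an assumed rnn_dim of 100 with enhanced_mention enabled:

rnn_dim = 100                # assumed value; the 300 and the +50 come from the snippet
output_dim = rnn_dim * 2     # bidirectional sentence encoder -> 200
mention_dim = 300 + 50       # 300-d mention embedding + 50-d CNN character features -> 350
output_dim += mention_dim    # decoder input size -> 550
print(output_dim)            # 550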
Example #4
    def __init__(self, args, answer_num):
        super(Bert, self).__init__(args, answer_num)

        # --- BERT ---
        if args.model_type == 'bert_uncase_small':
            print('==> Loading BERT config from ' +
                  constant.BERT_UNCASED_SMALL_CONFIG)
            self.bert_config = BertConfig.from_json_file(
                constant.BERT_UNCASED_SMALL_CONFIG)
        else:
            raise NotImplementedError
        self.bert = BertModel(self.bert_config)
        self.dropout = nn.Dropout(self.bert_config.hidden_dropout_prob)

        if args.data_setup == 'joint' and args.multitask:
            print("Multi-task learning")
            self.decoder = MultiSimpleDecoder(self.bert_config.hidden_size)
        else:
            self.decoder = SimpleDecoder(self.bert_config.hidden_size,
                                         answer_num)
Example #5
    def __init__(self, args, answer_num):
        super(Model, self).__init__()
        self.args = args
        self.output_dim = args.rnn_dim * 2
        self.mention_dropout = nn.Dropout(args.mention_dropout)
        self.input_dropout = nn.Dropout(args.input_dropout)
        self.dim_hidden = args.dim_hidden
        self.embed_dim = 300
        self.mention_dim = 300
        self.lstm_type = args.lstm_type
        self.enhanced_mention = args.enhanced_mention
        if args.enhanced_mention:
            self.head_attentive_sum = SelfAttentiveSum(self.mention_dim, 1)
            self.cnn = CNN()
            self.mention_dim += 50
        self.output_dim += self.mention_dim

        if args.model_debug:
            self.mention_proj = nn.Linear(self.mention_dim, 2 * args.rnn_dim)
            self.attn = nn.Linear(2 * args.rnn_dim, 2 * args.rnn_dim)
            self.fusion = Fusion(2 * args.rnn_dim)
            self.output_dim = 2 * args.rnn_dim * 2

        self.batch_num = 0

        if args.add_regu:
            corr_matrix, _, _, mask, mask_inverse = build_concurr_matrix(
                goal=args.goal)
            corr_matrix -= np.identity(corr_matrix.shape[0])
            self.corr_matrix = torch.from_numpy(corr_matrix).to(
                torch.device('cuda')).float()
            self.incon_mask = torch.from_numpy(mask).to(
                torch.device('cuda')).float()
            self.con_mask = torch.from_numpy(mask_inverse).to(
                torch.device('cuda')).float()

            self.b = nn.Parameter(torch.rand(corr_matrix.shape[0], 1))
            self.b_ = nn.Parameter(torch.rand(corr_matrix.shape[0], 1))

        # Defining LSTM here.
        self.attentive_sum = SelfAttentiveSum(args.rnn_dim * 2, 100)
        if self.lstm_type == "two":
            self.left_lstm = nn.LSTM(self.embed_dim,
                                     100,
                                     bidirectional=True,
                                     batch_first=True)
            self.right_lstm = nn.LSTM(self.embed_dim,
                                      100,
                                      bidirectional=True,
                                      batch_first=True)
        elif self.lstm_type == 'single':
            self.lstm = nn.LSTM(self.embed_dim + 50,
                                args.rnn_dim,
                                bidirectional=True,
                                batch_first=True)
            self.token_mask = nn.Linear(4, 50)

        if args.self_attn:
            self.embed_proj = nn.Linear(self.embed_dim + 50, 2 * args.rnn_dim)
            self.encoder = SimpleEncoder(2 * args.rnn_dim,
                                         head=4,
                                         layer=1,
                                         dropout=0.2)

        self.loss_func = nn.BCEWithLogitsLoss()
        self.sigmoid_fn = nn.Sigmoid()
        self.goal = args.goal
        self.multitask = args.multitask

        if args.data_setup == 'joint' and args.multitask and args.gcn:
            print("Multi-task learning with gcn on labels")
            self.decoder = GCNMultiDecoder(self.output_dim)
        elif args.data_setup == 'joint' and args.multitask:
            print("Multi-task learning")
            self.decoder = MultiSimpleDecoder(self.output_dim)
        elif args.data_setup == 'joint' and not args.multitask and args.gcn:
            print("Joint training with GCN simple decoder")
            self.decoder = GCNSimpleDecoder(self.output_dim, answer_num,
                                            "open")
        elif args.goal == 'onto' and args.gcn:
            print("Ontonotes with gcn decoder")
            self.decoder = GCNSimpleDecoder(self.output_dim, answer_num,
                                            "onto")
        else:
            print("Ontonotes using simple decoder")
            self.decoder = SimpleDecoder(self.output_dim, answer_num)
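To keep the branching above readable, the flag combinations and the decoder each one selects can be condensed as follows; this only restates the code and adds no behavior:

# data_setup='joint', multitask=True,  gcn=True   -> GCNMultiDecoder(output_dim)
# data_setup='joint', multitask=True,  gcn=False  -> MultiSimpleDecoder(output_dim)
# data_setup='joint', multitask=False, gcn=True   -> GCNSimpleDecoder(output_dim, answer_num, "open")
# goal='onto',                         gcn=True   -> GCNSimpleDecoder(output_dim, answer_num, "onto")
# anything else                                   -> SimpleDecoder(output_dim, answer_num)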