import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

# Repo-internal symbols (TokenEncoder, RelationEncoder, GraphTransformer,
# Transformer, DecodeLayer, SinusoidalPositionalEmbedding, SelfAttentionMask,
# Beam, Hypothesis, search_by_batch, ...) come from the surrounding project.


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Linear') != -1 or classname.find('Conv1d') != -1:
        if hasattr(m, 'weight') and m.weight is not None:
            GraphTransformer.init_weight(m.weight)
        if hasattr(m, 'bias') and m.bias is not None:
            GraphTransformer.init_bias(m.bias)
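# --- Illustrative usage sketch (not part of the original file). `weights_init`
# is written for nn.Module.apply, which visits every submodule recursively, so
# any Linear/Conv1d anywhere in the tree receives GraphTransformer's init
# scheme. The demo module below is hypothetical.
def _weights_init_demo():
    demo = nn.Sequential(nn.Linear(16, 16), nn.Conv1d(4, 8, kernel_size=3))
    demo.apply(weights_init)  # in-place init of every matching submodule
    return demo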
class Reasoning_AMR_CN_DUAL(nn.Module):
    def __init__(self, vocabs, concept_char_dim, concept_dim, cnn_filters,
                 char2concept_dim, rel_dim, rnn_hidden_size, rnn_num_layers,
                 embed_dim, bert_embed_dim, ff_embed_dim, num_heads, dropout,
                 snt_layer, graph_layers, pretrained_file, device, batch_size,
                 model_type, bert_config, bert_model, bert_tokenizer,
                 bert_max_length, n_answers, model, max_conceptnet_length,
                 conceptnet_path):
        super(Reasoning_AMR_CN_DUAL, self).__init__()
        self.vocabs = vocabs
        self.embed_scale = math.sqrt(embed_dim)
        self.concept_encoder = TokenEncoder(vocabs['concept'],
                                            vocabs['concept_char'],
                                            concept_char_dim, concept_dim,
                                            embed_dim, cnn_filters,
                                            char2concept_dim, dropout,
                                            pretrained_file)
        self.relation_encoder = RelationEncoder(vocabs['relation'], rel_dim,
                                                embed_dim, rnn_hidden_size,
                                                rnn_num_layers, dropout)
        self.graph_encoder = GraphTransformer(graph_layers, embed_dim,
                                              ff_embed_dim, num_heads, dropout)
        self.transformer = Transformer(snt_layer, embed_dim, ff_embed_dim,
                                       num_heads, dropout, with_external=True)
        self.c_transformer = Transformer(snt_layer, embed_dim, ff_embed_dim,
                                         num_heads, dropout,
                                         with_external=True)
        self.f_transformer = Transformer(snt_layer, bert_embed_dim + embed_dim,
                                         ff_embed_dim, num_heads, dropout,
                                         with_external=True)
        self.pretrained_file = pretrained_file
        self.embed_dim = embed_dim
        self.concept_dim = concept_dim
        self.max_conceptnet_len = max_conceptnet_length
        self.token_position = SinusoidalPositionalEmbedding(embed_dim, device)
        self.concept_depth = nn.Embedding(32, embed_dim)
        self.token_embed_layer_norm = nn.LayerNorm(embed_dim)
        self.concept_embed_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_mask = SelfAttentionMask(device)
        self.dropout = dropout
        self.conceptnet_path = conceptnet_path
        self.bert_config = bert_config
        self.bert_model = bert_model
        self.bert_tokenizer = bert_tokenizer
        self.bert_max_length = bert_max_length
        self.answer_len = n_answers
        self.device = device
        self.model_type = model_type
        self.model = model
        self.batch_size = batch_size
        self.classifier = nn.Linear(bert_embed_dim + embed_dim, 1)
        self.loss_fct = CrossEntropyLoss()
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.concept_depth.weight, 0.)

    def encoder_attn(self, inp):
        with torch.no_grad():
            concept_repr = self.embed_scale * self.concept_encoder(
                inp['concept'], inp['concept_char']) + self.concept_depth(
                    inp['concept_depth'])
            concept_repr = self.concept_embed_layer_norm(concept_repr)
            concept_mask = torch.eq(inp['concept'],
                                    self.vocabs['concept'].padding_idx)
            relation = self.relation_encoder(inp['relation_bank'],
                                             inp['relation_length'])
            relation[0, :] = 0.
            relation = relation[inp['relation']]
            sum_relation = relation.sum(dim=3)
            num_valid_paths = inp['relation'].ne(0).sum(dim=3).clamp_(min=1)
            divisor = num_valid_paths.unsqueeze(-1).type_as(sum_relation)
            relation = sum_relation / divisor
            attn, attn_weights = self.graph_encoder.get_attn_weights(
                concept_repr, relation, self_padding_mask=concept_mask)
        return attn

    def encode_cn_step(self, inp, i, train=True):
        cn_concept_input = inp['cn_concept'][i][:, 0].unsqueeze(1)
        cn_concept_char_input = inp['cn_concept_char'][i][:, 0].unsqueeze(1)
        cn_concept_depth_input = inp['cn_concept_depth'][i][:, 0].unsqueeze(1)
        cn_relation_bank_input = inp['cn_relation_bank'][i]
        cn_relation_length_input = inp['cn_relation_length'][i]
        cn_relation_input = inp['cn_relation'][i][:, :, 0].unsqueeze(2)

        concept_repr = self.embed_scale * self.concept_encoder(
            cn_concept_input, cn_concept_char_input) + self.concept_depth(
                cn_concept_depth_input)
        concept_repr = self.concept_embed_layer_norm(concept_repr)
        concept_mask = torch.eq(cn_concept_input,
                                self.vocabs['concept'].padding_idx)
        relation = self.relation_encoder(cn_relation_bank_input,
                                         cn_relation_length_input)
        # `train` may arrive as a bool or its string form; str() handles both.
        if str(train) == 'True':
            relation = relation.index_select(
                0, cn_relation_input.reshape(-1)).view(
                    *cn_relation_input.size(), -1)
        else:
            relation[0, :] = 0.  # cn_relation_length x dim
            relation = relation[cn_relation_input]  # i x j x bsz x num x dim
        sum_relation = relation.sum(dim=3)  # i x j x bsz x dim
        num_valid_paths = cn_relation_input.ne(0).sum(dim=3).clamp_(
            min=1)  # i x j x bsz
        divisor = num_valid_paths.unsqueeze(-1).type_as(
            sum_relation)  # i x j x bsz x 1
        relation = sum_relation / divisor  # i x j x bsz x dim
        concept_repr = self.graph_encoder(concept_repr, relation,
                                          self_padding_mask=concept_mask)
        return concept_repr

    def encoder_cn_attn(self, inp, i):
        with torch.no_grad():
            cn_concept_input = inp['cn_concept'][i][:, 0].unsqueeze(1)
            cn_concept_char_input = inp['cn_concept_char'][i][:,
                                                              0].unsqueeze(1)
            cn_concept_depth_input = inp['cn_concept_depth'][i][:,
                                                                0].unsqueeze(1)
            cn_relation_bank_input = inp['cn_relation_bank'][i]
            cn_relation_length_input = inp['cn_relation_length'][i]
            cn_relation_input = inp['cn_relation'][i][:, :, 0].unsqueeze(2)

            concept_repr = self.embed_scale * self.concept_encoder(
                cn_concept_input, cn_concept_char_input) + self.concept_depth(
                    cn_concept_depth_input)
            concept_repr = self.concept_embed_layer_norm(concept_repr)
            concept_mask = torch.eq(cn_concept_input,
                                    self.vocabs['concept'].padding_idx)
            relation = self.relation_encoder(
                cn_relation_bank_input,
                cn_relation_length_input)  # num_relations x embed_dim
            relation[0, :] = 0.
            # cn_relation_length x dim
            relation = relation[cn_relation_input]  # i x j x bsz x num x dim
            sum_relation = relation.sum(dim=3)  # i x j x bsz x dim
            num_valid_paths = cn_relation_input.ne(0).sum(dim=3).clamp_(
                min=1)  # i x j x bsz
            divisor = num_valid_paths.unsqueeze(-1).type_as(
                sum_relation)  # i x j x bsz x 1
            relation = sum_relation / divisor  # i x j x bsz x dim
            concept_repr = self.graph_encoder(concept_repr, relation,
                                              self_padding_mask=concept_mask)
            attn = self.graph_encoder.get_attn_weights(
                concept_repr, relation, self_padding_mask=concept_mask)
        return attn

    def convert_batch_to_bert_features(self,
                                       data,
                                       max_seq_length,
                                       tokenizer,
                                       cls_token_at_end=False,
                                       cls_token='[CLS]',
                                       cls_token_segment_id=1,
                                       sep_token='[SEP]',
                                       sequence_a_segment_id=0,
                                       sequence_b_segment_id=1,
                                       sep_token_extra=False,
                                       pad_token_segment_id=0,
                                       pad_on_left=False,
                                       pad_token=0,
                                       mask_padding_with_zero=True):
        features = []
        questions = [" ".join(x for x in sent) for sent in data['token_data']]
        answers = data['answers']
        choices_features = []
        for i, question in enumerate(questions):
            for answer in answers[i]:
                token_a = tokenizer.tokenize(question)
                token_b = tokenizer.tokenize(answer)
                tokens = token_a + [sep_token]
                segment_ids = [sequence_a_segment_id] * len(tokens)
                if token_b:
                    tokens += token_b + [sep_token]
                    segment_ids += [sequence_b_segment_id] * (len(token_b) + 1)
                if cls_token_at_end:
                    tokens = tokens + [cls_token]
                    segment_ids = segment_ids + [cls_token_segment_id]
                else:
                    tokens = [cls_token] + tokens
                    segment_ids = [cls_token_segment_id] + segment_ids
                input_ids = tokenizer.convert_tokens_to_ids(tokens)
                input_mask = [1 if mask_padding_with_zero else 0
                              ] * len(input_ids)
                # Zero-pad up to the sequence length.
                padding_length = max_seq_length - len(input_ids)
                if padding_length > 0:
                    if pad_on_left:
                        input_ids = ([pad_token] * padding_length) + input_ids
                        input_mask = ([0 if mask_padding_with_zero else 1] *
                                      padding_length) + input_mask
                        segment_ids = ([pad_token_segment_id] *
                                       padding_length) + segment_ids
                    else:
                        input_ids = input_ids + ([pad_token] * padding_length)
                        input_mask = input_mask + (
                            [0 if mask_padding_with_zero else 1] *
                            padding_length)
                        segment_ids = segment_ids + ([pad_token_segment_id] *
                                                     padding_length)
                else:
                    input_ids = input_ids[:max_seq_length]
                    input_mask = input_mask[:max_seq_length]
                    segment_ids = segment_ids[:max_seq_length]
                assert len(input_ids) == max_seq_length
                assert len(input_mask) == max_seq_length
                assert len(segment_ids) == max_seq_length
                choices_features.append((input_ids, input_mask, segment_ids))
            features.append(choices_features)
            choices_features = []
        return features

    def prepare_bert_input(self, data, tokenizer):
        move_to_cuda(data, self.device)
        features = self.convert_batch_to_bert_features(
            data=data,
            max_seq_length=self.bert_max_length,
            tokenizer=tokenizer,
        )
        move_to_cuda(features, self.device)
        all_input_ids = torch.tensor(
            [f[0] for feature in features for f in feature],
            dtype=torch.long).view(self.batch_size, -1,
                                   self.bert_max_length).to(self.device)
        all_input_mask = torch.tensor(
            [f[1] for feature in features for f in feature],
            dtype=torch.long).view(self.batch_size, -1,
                                   self.bert_max_length).to(self.device)
        all_segment_ids = torch.tensor(
            [f[2] for feature in features for f in feature],
            dtype=torch.long).view(self.batch_size, -1,
                                   self.bert_max_length).to(self.device)
        return all_input_ids, all_input_mask, all_segment_ids

    def prepare_graph_state(self, graph_state, ans_len, concept_dim):
        tot_initial = torch.tensor(1).to(self.device)
        j = 0
        while j < (1 * ans_len) - 1:
            initial = graph_state[0][j].view(1, -1).to(self.device)
            for i in graph_state[1:]:  # i = [5 x 512], 7 items
                com_tensor = i[j + 1].view(1, -1).to(self.device)
                initial = torch.cat([initial, com_tensor], dim=0)
            if j == 0:
                tot_initial = initial.view(1, -1, concept_dim)
            j += 1
            initial = initial.view(1, -1, concept_dim)
            tot_initial = torch.cat([tot_initial, initial], dim=0)
        return tot_initial

    def forward(self, data, train):
        answer_len = self.answer_len
        tot_concept_reprs = []
        for i in range(self.batch_size):
            ## AMR-GTOS
            concept_repr = self.encode_cn_step(
                data, i,
                train=train)  # concept_seq_len x 1 x concept_embed_size
            concept_repr = self.transformer(
                concept_repr,
                kv=None)  # concept_seq_len x bsz x concept_embed_size
            if concept_repr.size()[1] == 1:
                concept_repr = concept_repr.squeeze().unsqueeze(0).mean(
                    1).unsqueeze(1)
            else:
                concept_repr = self.prepare_graph_state(
                    concept_repr,
                    concept_repr.size()[1],
                    self.embed_dim).mean(1).unsqueeze(
                        1)  # bsz x 1 x concept_embed_size
            concept_repr = concept_repr.repeat(
                1, answer_len, 1)  # 1 x n_answers x concept_embed_size
            tot_concept_repr = self.c_transformer(concept_repr, kv=None)
            tot_concept_reprs.append(tot_concept_repr)
        tot_concept_reprs = torch.squeeze(torch.stack(tot_concept_reprs), 1)

        ids = data['id']
        labels = data['answer_key']
        labels = torch.tensor(labels).to(self.device)
        if labels.dim() == 0:
            labels = labels.unsqueeze(0)

        # Prepare BERT input: bsz x n_answers x bert_max_length (e.g. 3 x 5 x 128).
        all_input_ids, all_input_mask, all_segment_ids = self.prepare_bert_input(
            data, self.bert_tokenizer)
        logits = self.bert_model(
            input_ids=all_input_ids,
            attention_mask=all_input_mask,
            token_type_ids=all_segment_ids if self.model_type
            in ['bert-base-cased', 'xlnet-base-cased'] else None,
            labels=labels,
            n_answers=answer_len)

        bsz = len(ids)
        final_logits = torch.cat([logits, tot_concept_reprs], 2)
        final_logits = self.f_transformer(final_logits, kv=None)
        final_logits = self.classifier(final_logits).squeeze().view(bsz, -1)
        return final_logits, labels, ids
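# --- Illustrative sketch (not part of the model code): the masked mean over
# relation paths used by encoder_attn / encode_cn_step / encoder_cn_attn above.
# Row 0 of the relation bank is the padding relation and is zeroed, so summing
# the gathered vectors and dividing by the count of non-zero path indices
# averages over valid paths only. All shapes here are made up for the demo.
def _relation_masked_mean_demo():
    bank = torch.randn(6, 4)                 # 6 relation types, embed dim 4
    bank[0, :] = 0.                          # index 0 = padding relation
    idx = torch.randint(0, 6, (2, 2, 3, 5))  # i x j x bsz x num paths
    gathered = bank[idx]                     # i x j x bsz x num x dim
    total = gathered.sum(dim=3)              # i x j x bsz x dim
    n_valid = idx.ne(0).sum(dim=3).clamp(min=1)  # i x j x bsz
    return total / n_valid.unsqueeze(-1).type_as(total)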
class Generator(nn.Module):
    def __init__(self, vocabs, word_char_dim, word_dim, concept_char_dim,
                 concept_dim, cnn_filters, char2word_dim, char2concept_dim,
                 rel_dim, rnn_hidden_size, rnn_num_layers, embed_dim,
                 ff_embed_dim, num_heads, dropout, snt_layers, graph_layers,
                 inference_layers, pretrained_file, device):
        super(Generator, self).__init__()
        self.vocabs = vocabs
        self.embed_dim = embed_dim
        self.embed_scale = math.sqrt(embed_dim)
        self.dropout = dropout
        self.device = device
        # encoder
        self.concept_encoder = TokenEncoder(vocabs['concept'],
                                            vocabs['concept_char'],
                                            concept_char_dim, concept_dim,
                                            embed_dim, cnn_filters,
                                            char2concept_dim, dropout,
                                            pretrained_file)
        self.concept_depth = nn.Embedding(256, embed_dim)
        self.relation_encoder = RelationEncoder(vocabs['relation'], rel_dim,
                                                embed_dim, rnn_hidden_size,
                                                rnn_num_layers, dropout)
        self.graph_encoder = GraphTransformer(graph_layers, embed_dim,
                                              ff_embed_dim, num_heads, dropout)
        self.concept_embed_layer_norm = nn.LayerNorm(embed_dim)
        self.probe_generator = nn.Linear(embed_dim, embed_dim)
        # decoder
        self.token_encoder = TokenEncoder(vocabs['token'],
                                          vocabs['token_char'], word_char_dim,
                                          word_dim, embed_dim, cnn_filters,
                                          char2word_dim, dropout,
                                          pretrained_file)
        self.snt_encoder = Transformer(snt_layers, embed_dim, ff_embed_dim,
                                       num_heads, dropout, with_external=True)
        self.token_position = SinusoidalPositionalEmbedding(embed_dim, device)
        self.token_embed_layer_norm = nn.LayerNorm(embed_dim)
        self.self_attn_mask = SelfAttentionMask(device)
        #------------
        self.decoder = DecodeLayer(vocabs, inference_layers, embed_dim,
                                   ff_embed_dim, num_heads, concept_dim,
                                   rel_dim, dropout)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.probe_generator.weight, std=0.02)
        nn.init.constant_(self.probe_generator.bias, 0.)
        nn.init.constant_(self.concept_depth.weight, 0.)
    def encoder_attn(self, inp):
        with torch.no_grad():
            concept_repr = self.embed_scale * self.concept_encoder(
                inp['concept'], inp['concept_char']) + self.concept_depth(
                    inp['concept_depth'])
            concept_repr = self.concept_embed_layer_norm(concept_repr)
            concept_mask = torch.eq(inp['concept'],
                                    self.vocabs['concept'].padding_idx)
            relation = self.relation_encoder(inp['relation_bank'],
                                             inp['relation_length'])
            relation = relation.index_select(
                0, inp['relation'].contiguous().view(-1)).contiguous().view(
                    *inp['relation'].size(), -1)
            attn = self.graph_encoder.get_attn_weights(
                concept_repr, relation, self_padding_mask=concept_mask)
            # nlayers x tgt_len x src_len x bsz x num_heads
        return attn

    def encode_step(self, inp):
        concept_repr = self.embed_scale * self.concept_encoder(
            inp['concept'], inp['concept_char']) + self.concept_depth(
                inp['concept_depth'])
        concept_repr = self.concept_embed_layer_norm(concept_repr)
        concept_mask = torch.eq(inp['concept'],
                                self.vocabs['concept'].padding_idx)
        relation = self.relation_encoder(inp['relation_bank'],
                                         inp['relation_length'])
        relation = relation.index_select(
            0, inp['relation'].contiguous().view(-1)).contiguous().view(
                *inp['relation'].size(), -1)
        concept_repr = self.graph_encoder(concept_repr, relation,
                                          self_padding_mask=concept_mask)
        #------------
        # Split off the leading position as the decoder's probe vector; the
        # remaining positions are the actual concept states.
        probe = torch.tanh(self.probe_generator(concept_repr[:1]))
        concept_repr = concept_repr[1:]
        concept_mask = concept_mask[1:]
        return concept_repr, concept_mask, probe

    def work(self, data, beam_size, max_time_step, min_time_step=1):
        with torch.no_grad():
            concept_repr, concept_mask, probe = self.encode_step(data)
            mem_dict = {
                'graph_state': concept_repr,
                'graph_padding_mask': concept_mask,
                'probe': probe,
                'local_idx2token': data['local_idx2token'],
                'cp_seq': data['cp_seq']
            }
            init_state_dict = {}
            init_hyp = Hypothesis(init_state_dict, [STR], 0.)
            bsz = concept_repr.size(1)
            beams = [
                Beam(beam_size, min_time_step, max_time_step, [init_hyp],
                     self.device) for i in range(bsz)
            ]
            search_by_batch(self, beams, mem_dict)
        return beams

    def prepare_incremental_input(self, step_seq):
        token = ListsToTensor(step_seq, self.vocabs['token'])
        token_char = ListsofStringToTensor(step_seq,
                                           self.vocabs['token_char'])
        token = move_to_device(token, self.device)
        token_char = move_to_device(token_char, self.device)
        return token, token_char

    def decode_step(self, inp, state_dict, mem_dict, offset, topk):
        step_token, step_token_char = inp
        graph_repr = mem_dict['graph_state']
        graph_padding_mask = mem_dict['graph_padding_mask']
        probe = mem_dict['probe']
        copy_seq = mem_dict['cp_seq']
        local_vocabs = mem_dict['local_idx2token']
        _, bsz, _ = graph_repr.size()

        token_repr = self.embed_scale * self.token_encoder(
            step_token, step_token_char) + self.token_position(
                step_token, offset)
        token_repr = self.token_embed_layer_norm(token_repr)

        new_state_dict = {}
        for idx, layer in enumerate(self.snt_encoder.layers):
            name_i = 'token_repr_%d' % idx
            if name_i in state_dict:
                prev_token_repr = state_dict[name_i]
                new_token_repr = torch.cat([prev_token_repr, token_repr], 0)
            else:
                new_token_repr = token_repr
            new_state_dict[name_i] = new_token_repr
            token_repr, _, _ = layer(token_repr,
                                     kv=new_token_repr,
                                     external_memories=graph_repr,
                                     external_padding_mask=graph_padding_mask)

        name = 'token_state'
        if name in state_dict:
            prev_token_state = state_dict[name]
            new_token_state = torch.cat([prev_token_state, token_repr], 0)
        else:
            new_token_state = token_repr
        new_state_dict[name] = new_token_state
        #------------------
        LL = self.decoder(probe,
                          graph_repr,
                          new_token_state,
                          graph_padding_mask,
                          None,
                          None,
                          copy_seq,
                          work=True)

        def idx2token(idx, local_vocab):
            if idx in local_vocab:
                return local_vocab[idx]
            return self.vocabs['predictable_token'].idx2token(idx)

        topk_scores, topk_token = torch.topk(LL.squeeze(0), topk, 1)  # bsz x k
        results = []
        for s, t, local_vocab in zip(topk_scores.tolist(),
                                     topk_token.tolist(), local_vocabs):
            res = []
            for score, token in zip(s, t):
                res.append((idx2token(token, local_vocab), score))
            results.append(res)
        return new_state_dict, results

    def forward(self, data):
        concept_repr, concept_mask, probe = self.encode_step(data)
        # decoding
        token_repr = self.embed_scale * self.token_encoder(
            data['token_in'], data['token_char_in']) + self.token_position(
                data['token_in'])
        token_repr = self.token_embed_layer_norm(token_repr)
        token_repr = F.dropout(token_repr,
                               p=self.dropout,
                               training=self.training)
        token_mask = torch.eq(data['token_in'],
                              self.vocabs['token'].padding_idx)
        #--------
        attn_mask = self.self_attn_mask(data['token_in'].size(0))
        token_repr = self.snt_encoder(token_repr,
                                      self_padding_mask=token_mask,
                                      self_attn_mask=attn_mask,
                                      external_memories=concept_repr,
                                      external_padding_mask=concept_mask)
        probe = probe.expand_as(token_repr)  # tgt_len x bsz x embed_dim
        return self.decoder(probe, concept_repr, token_repr, concept_mask,
                            token_mask, attn_mask, data['cp_seq'],
                            target=data['token_out'])
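# --- Illustrative sketch (not part of the model code): the per-layer state
# cache used in Generator.decode_step above. Each decoding step appends its
# new token representation along the time axis, so self-attention at step t
# can attend over the whole generated prefix without re-encoding it. Tensor
# shapes and the two-layer loop are invented for the demo; the real cache keys
# are 'token_repr_%d' and 'token_state'.
def _incremental_cache_demo():
    state_dict = {}
    for step in range(3):
        step_repr = torch.randn(1, 2, 8)  # one new step: 1 x bsz x dim
        for layer_idx in range(2):        # two hypothetical decoder layers
            name = 'token_repr_%d' % layer_idx
            if name in state_dict:
                kv = torch.cat([state_dict[name], step_repr], 0)  # prefix + new
            else:
                kv = step_repr
            state_dict[name] = kv         # grows to (step + 1) x bsz x dim
    return state_dict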
model = GraphTransformer(
    dim=650,
    # n_layers=14,
    n_layers=1,
    d_inner=3800,
    fdim=200,
    final_dim=280,
    dropout=0.03,
    dropatt=0.0,
    final_dropout=0.04,
    n_head=10,
    num_atom_types=NUM_ATOM_TYPES,
    num_bond_types=NUM_BOND_TYPES,
    num_triplet_types=NUM_TRIPLET_TYPES,
    num_quad_types=NUM_QUAD_TYPES,
    dist_embedding='sine',
    atom_angle_embedding='sine',
    trip_angle_embedding='sine',
    quad_angle_embedding='sine',
    wnorm=True,
    use_quad=True,
).to(device)
model.eval()

# %%
for step, batch in enumerate(loader):
    batch = [