def forward(self, x, w=1.0):
    isscaler = np.isscalar(w)
    assert self.padding_idx is not None

    if isscaler or w.size(0) == 1:
        weight = sinusoidal_encode(self.weight, w)
        return F.embedding(
            x, weight, self.padding_idx, self.max_norm, self.norm_type,
            self.scale_grad_by_freq, self.sparse)
def forward(self, inputs):  # pylint: disable=arguments-differ
    original_inputs = inputs
    if original_inputs.dim() > 2:
        inputs = inputs.view(-1, inputs.size(-1))
    embedded = embedding(inputs, self.weight,
                         max_norm=self.max_norm,
                         norm_type=self.norm_type,
                         scale_grad_by_freq=self.scale_grad_by_freq,
                         sparse=self.sparse)
    if original_inputs.dim() > 2:
        view_args = list(original_inputs.size()) + [embedded.size(-1)]
        embedded = embedded.view(*view_args)
    if self._projection:
        projection = self._projection
        for _ in range(embedded.dim() - 2):
            projection = TimeDistributed(projection)
        embedded = projection(embedded)
    return embedded
def forward(self, words, dropout=0.1, scale=None):
    if dropout:
        size = (self.embed.weight.size(0), 1)
        mask = Variable(dropout_mask(self.embed.weight.data, size, dropout))
        masked_embed_weight = mask * self.embed.weight
    else:
        masked_embed_weight = self.embed.weight
    if scale:
        masked_embed_weight = scale * masked_embed_weight

    padding_idx = self.embed.padding_idx
    if padding_idx is None:
        padding_idx = -1

    if IS_TORCH_04:
        X = F.embedding(words, masked_embed_weight, padding_idx,
                        self.embed.max_norm, self.embed.norm_type,
                        self.embed.scale_grad_by_freq, self.embed.sparse)
    else:
        X = self.embed._backend.Embedding.apply(
            words, masked_embed_weight, padding_idx, self.embed.max_norm,
            self.embed.norm_type, self.embed.scale_grad_by_freq,
            self.embed.sparse)
    return X
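# --- Usage sketch (not from the original source). A minimal, self-contained version of the
# --- same row-level embedding dropout, assuming `dropout_mask` above draws a Bernoulli
# --- keep-mask over vocabulary rows and rescales the survivors by 1/(1-p).
import torch
import torch.nn.functional as F

def embedding_dropout(embed: torch.nn.Embedding, words: torch.Tensor, p: float = 0.1) -> torch.Tensor:
    """Drop whole embedding rows with probability p, rescaling the kept rows."""
    if embed.training and p > 0:
        keep = torch.bernoulli(
            torch.full((embed.weight.size(0), 1), 1 - p, device=embed.weight.device))
        masked_weight = embed.weight * keep / (1 - p)
    else:
        masked_weight = embed.weight
    return F.embedding(words, masked_weight, embed.padding_idx)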
def loss(self, data):
    user, item_i, item_j = data
    user = user.to(self.device)
    item_i = item_i.to(self.device)
    item_j = item_j.to(self.device)

    users_embedding = self.embed_user.weight
    items_embedding = self.embed_item.weight

    # three rounds of graph-convolution propagation over the user-item graph
    gcn1_users_embedding = (torch.sparse.mm(self.user_item_matrix, items_embedding)
                            + users_embedding.mul(self.d_i_train))
    gcn1_items_embedding = (torch.sparse.mm(self.item_user_matrix, users_embedding)
                            + items_embedding.mul(self.d_j_train))

    gcn2_users_embedding = (torch.sparse.mm(self.user_item_matrix, gcn1_items_embedding)
                            + gcn1_users_embedding.mul(self.d_i_train))
    gcn2_items_embedding = (torch.sparse.mm(self.item_user_matrix, gcn1_users_embedding)
                            + gcn1_items_embedding.mul(self.d_j_train))

    gcn3_users_embedding = (torch.sparse.mm(self.user_item_matrix, gcn2_items_embedding)
                            + gcn2_users_embedding.mul(self.d_i_train))
    gcn3_items_embedding = (torch.sparse.mm(self.item_user_matrix, gcn2_users_embedding)
                            + gcn2_items_embedding.mul(self.d_j_train))

    # concatenate the embeddings from all propagation layers
    gcn_users_embedding = torch.cat(
        (users_embedding, gcn1_users_embedding, gcn2_users_embedding,
         gcn3_users_embedding), -1)
    gcn_items_embedding = torch.cat(
        (items_embedding, gcn1_items_embedding, gcn2_items_embedding,
         gcn3_items_embedding), -1)

    self.user_result = gcn_users_embedding
    self.item_result = gcn_items_embedding

    user_emb = F.embedding(user, gcn_users_embedding)
    item_i_emb = F.embedding(item_i - self.num_user, gcn_items_embedding)
    item_j_emb = F.embedding(item_j - self.num_user, gcn_items_embedding)

    prediction_i = (user_emb * item_i_emb).sum(dim=-1)
    prediction_j = (user_emb * item_j_emb).sum(dim=-1)

    l2_regulization = self.weight_decay * (user_emb ** 2 + item_i_emb ** 2
                                           + item_j_emb ** 2).sum(dim=-1)

    # BPR loss plus L2 regularization
    loss = -((prediction_i - prediction_j)).sigmoid().log().mean() + l2_regulization.mean()
    return loss, l2_regulization.mean()
def forward(self, input_dict):
    # unpack inputs
    entity_indices = input_dict["entity_indices"]
    text_indices = input_dict["text_indices"]
    text_lengths = input_dict["text_lengths"]
    triple_head_indices = input_dict["triple_head_indices"]
    triple_relation_indices = input_dict["triple_relation_indices"]
    triple_tail_indices = input_dict["triple_tail_indices"]
    adjacents = [
        input_dict["adjacent_%d" % i]
        for i in range(self.config.relation_size + self.config.add_adj_size)
        if ("adjacent_%d" % i) in input_dict
    ]

    # embedding and encoding
    entity_embeddings = self.token_embedding(entity_indices)
    text_embeddings = self.token_embedding(text_indices)
    text_encodings = self.gru(text_embeddings, text_lengths)

    # shortcuts labeling
    relation_num = len(adjacents)
    adj_to_use = [i for i in range(2)]

    # R-GAT fusion
    fusioned_entity_embeddings = self.r_gat(
        [entity_embeddings, text_encodings, adj_to_use] + adjacents)
    fusioned_entity_embeddings = self.dense(fusioned_entity_embeddings)

    # DistMult decode
    triple_heads = F.embedding(triple_head_indices, fusioned_entity_embeddings)
    triple_tails = F.embedding(triple_tail_indices, fusioned_entity_embeddings)
    triple_relations = self.relation_embedding(triple_relation_indices)

    # score
    mask = [1 if i < 5 else -1 for i in triple_relation_indices]
    mask = torch.tensor(mask, dtype=torch.float)
    score = triple_heads * triple_relations * triple_tails  # 2600*128
    score = torch.sum(score, dim=-1)  # 2600
    score = torch.dot(score, mask)
    score = torch.sigmoid(score)  # 2600
    score = -0.001 * torch.log(score)

    # Text encoder
    packed_embedded = nn.utils.rnn.pack_padded_sequence(
        text_embeddings, text_lengths, batch_first=True)
    packed_output, (hidden, cell) = self.lstm(packed_embedded)
    # concat the final forward and backward hidden state
    hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)

    W = nn.Parameter(torch.randn(1, fusioned_entity_embeddings.shape[0]))
    transformed_entity_embeddings = torch.mm(W, fusioned_entity_embeddings)

    gamma = 0.01
    knowledge_guided_hidden = (1 - gamma) * hidden + gamma * self.dense3(
        transformed_entity_embeddings)
    predicted = self.dense2(knowledge_guided_hidden)
    predicted = torch.sigmoid(predicted)
    predicted = torch.squeeze(predicted, dim=0)
    return predicted, score
def encode(self, source_padded: torch.Tensor,
           source_lengths: List[int]) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """ Apply the encoder to source sentences to obtain encoder hidden states.
        Additionally, take the final states of the encoder and project them to obtain initial states for decoder.

    @param source_padded (Tensor): Tensor of padded source sentences with shape (src_len, b), where
                                   b = batch_size, src_len = maximum source sentence length. Note that
                                   these have already been sorted in order of longest to shortest sentence.
    @param source_lengths (List[int]): List of actual lengths for each of the source sentences in the batch
    @returns enc_hiddens (Tensor): Tensor of hidden units with shape (b, src_len, h*2), where
                                   b = batch size, src_len = maximum source sentence length, h = hidden size.
    @returns dec_init_state (tuple(Tensor, Tensor)): Tuple of tensors representing the decoder's initial
                                   hidden state and cell.
    """
    enc_hiddens, dec_init_state = None, None

    ### YOUR CODE HERE (~ 8 Lines)
    ### TODO:
    ###     1. Construct Tensor `X` of source sentences with shape (src_len, b, e) using the source model embeddings.
    ###        src_len = maximum source sentence length, b = batch size, e = embedding size. Note
    ###        that there is no initial hidden state or cell for the decoder.
    ###     2. Compute `enc_hiddens`, `last_hidden`, `last_cell` by applying the encoder to `X`.
    ###        - Before you can apply the encoder, you need to apply the `pack_padded_sequence` function to X.
    ###        - After you apply the encoder, you need to apply the `pad_packed_sequence` function to enc_hiddens.
    ###        - Note that the shape of the tensor returned by the encoder is (src_len, b, h*2) and we want to
    ###          return a tensor of shape (b, src_len, h*2) as `enc_hiddens`.
    ###     3. Compute `dec_init_state` = (init_decoder_hidden, init_decoder_cell):
    ###        - `init_decoder_hidden`:
    ###          `last_hidden` is a tensor shape (2, b, h). The first dimension corresponds to forwards and backwards.
    ###          Concatenate the forwards and backwards tensors to obtain a tensor shape (b, 2*h).
    ###          Apply the h_projection layer to this in order to compute init_decoder_hidden.
    ###          This is h_0^{dec} in the PDF. Here b = batch size, h = hidden size
    ###        - `init_decoder_cell`:
    ###          `last_cell` is a tensor shape (2, b, h). The first dimension corresponds to forwards and backwards.
    ###          Concatenate the forwards and backwards tensors to obtain a tensor shape (b, 2*h).
    ###          Apply the c_projection layer to this in order to compute init_decoder_cell.
    ###          This is c_0^{dec} in the PDF. Here b = batch size, h = hidden size
    ###
    ### See the following docs, as you may need to use some of the following functions in your implementation:
    ###     Pack the padded sequence X before passing to the encoder:
    ###         https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pack_padded_sequence
    ###     Pad the packed sequence, enc_hiddens, returned by the encoder:
    ###         https://pytorch.org/docs/stable/nn.html#torch.nn.utils.rnn.pad_packed_sequence
    ###     Tensor Concatenation:
    ###         https://pytorch.org/docs/stable/torch.html#torch.cat
    ###     Tensor Permute:
    ###         https://pytorch.org/docs/stable/tensors.html#torch.Tensor.permute

    src_padded_embedding = F.embedding(source_padded,
                                       self.model_embeddings.source.weight)  # (src_len, b, e)
    X = nn.utils.rnn.pack_padded_sequence(src_padded_embedding, source_lengths)
    enc_hiddens, (last_hidden, last_cell) = self.encoder(X)
    enc_hiddens = nn.utils.rnn.pad_packed_sequence(
        enc_hiddens, batch_first=True, total_length=source_padded.size(0))
    enc_hiddens = enc_hiddens[0]  # (b, src_len, hidden_size*2)

    dec_init_state_h = torch.cat((last_hidden[0], last_hidden[1]), 1)  # (b, hidden_size*2)
    init_decoder_hidden = self.h_projection(dec_init_state_h)          # (b, hidden_size)
    dec_init_state_c = torch.cat((last_cell[0], last_cell[1]), 1)
    init_decoder_cell = self.c_projection(dec_init_state_c)
    dec_init_state = (init_decoder_hidden, init_decoder_cell)
    ### END YOUR CODE

    return enc_hiddens, dec_init_state
def get_ter_prediction(self, predictor, estimator, sample, xml_model=None):
    ''' get ter prediction '''
    k = self.args.num_experts

    def get_gap_fea():
        ''' get NMT features '''
        predictor.eval()
        encoder = predictor.encoder
        encoder_out = torch.zeros(
            (sample['net_input']['src_tokens'].shape[1],
             sample['net_input']['src_tokens'].shape[0],
             encoder.output_embed_dim)).to(sample['target'].device)
        encoder_padding_mask = sample['net_input']['src_tokens'].eq(
            encoder.embed_positions.padding_idx)
        if not encoder_padding_mask.any():
            encoder_padding_mask = None
        encoder_out = encoder.encode(sample['net_input']['src_tokens'],
                                     encoder_out={
                                         'encoder_out': encoder_out,
                                         'encoder_padding_mask': encoder_padding_mask
                                     })

        net_outputs = []
        i_equals = []
        for i in range(k):
            decoder = predictor.decoder
            prev_output_tokens_k = sample['net_input']['prev_output_tokens'].clone()
            assert not prev_output_tokens_k.requires_grad
            prev_output_tokens_k[:, 0] = self.expert_index(i)

            # model derived features and dual model features
            net_output = predictor.decoder(prev_output_tokens_k, encoder_out)
            # B x T x dic_size
            lprobs = predictor.get_normalized_probs(net_output, log_probs=True)
            target = sample['target']

            co_attn = torch.zeros(
                (sample['target'].shape[1], sample['target'].shape[0],
                 predictor.encoder.output_embed_dim)).to(sample['target'].device)
            encoder_padding_mask = sample['target'].eq(
                predictor.encoder.embed_positions.padding_idx)
            if not encoder_padding_mask.any():
                encoder_padding_mask = None
            enc_out_dual = decoder.encode(sample['target'],
                                          encoder_out={
                                              'encoder_out': co_attn,
                                              'encoder_padding_mask': encoder_padding_mask
                                          })
            enc_out_dual = enc_out_dual['encoder_out'].transpose(0, 1)
            lprobs_dual = decoder.output_layer(enc_out_dual)
            lprobs_dual = utils.log_softmax(lprobs_dual, dim=-1)

            target_embeding = F.embedding(target, predictor.decoder.embed_tokens.weight)
            last_output = net_output[1]['last_output'] * target_embeding
            pre_qefv = torch.mul(last_output, target_embeding)
            post_qefv = last_output
            pre_qefv_dual = enc_out_dual * target_embeding * target_embeding
            post_qefv_dual = enc_out_dual * target_embeding

            # mismatch features
            target = target.unsqueeze(-1)
            i_gt = lprobs.gather(dim=-1, index=target)
            i_max, i_argmax = lprobs.max(dim=-1, keepdim=True)
            i_equal = torch.eq(i_argmax, target).type_as(i_gt)
            i_equals.append(i_equal)
            i_gap = i_max - i_gt
            # i_gt_dual = lprobs_dual.gather(dim=-1, index=target)
            # i_max_dual, i_argmax_dual = lprobs_dual.max(dim=-1, keepdim=True)
            mismatch_fea = torch.cat([i_gt, i_max, i_equal, i_gap], dim=-1)
            # i_gt_dual, i_max_dual, i_equal_dual, i_gap_dual,
            # i_gt-i_gt_dual, i_max-i_max_dual, i_equal-i_equal_dual, i_gap-i_gap_dual], dim=-1)

            net_outputs.append(mismatch_fea)
            net_outputs.append(post_qefv)
            net_outputs.append(pre_qefv)
            net_outputs.append(post_qefv_dual)
            net_outputs.append(pre_qefv_dual)

        net_outputs = torch.cat(net_outputs, dim=-1)  # -> B x K
        mask = 1 - torch.eq(sample['target'], 1).unsqueeze(dim=-1).type_as(net_outputs)
        mask = mask.repeat(1, 1, net_outputs.shape[-1])
        net_outputs = net_outputs * mask
        return net_outputs

    # NMT features
    mt_qefv_prim = get_gap_fea()

    xml_model.eval()
    with torch.no_grad():
        tensor = xml_model('fwd',
                           x=sample['xml_word_ids'],
                           lengths=sample['xml_lengths'],
                           langs=sample['langs'].cuda(),
                           causal=False).contiguous()

    if self.xml_tgt_only:
        ''' only extract target features for XLM '''
        xml_src_lengths = sample['xml_src_lengths']
        xml_tgt_lengths = sample['xml_tgt_lengths']
        tensor_ = tensor.transpose(0, 1)
        tensor = torch.unbind(tensor_)
        xml_word_ids = sample['xml_word_ids'].transpose(0, 1)
        xml_word_ids = torch.unbind(xml_word_ids)
        max_tgt_length = max(xml_tgt_lengths)
        max_tgt_length = max(max_tgt_length, mt_qefv_prim.shape[1])
        xml_tensor = torch.FloatTensor(len(tensor), max_tgt_length, 1024).fill_(0)
        xml_tgt_word_ids = torch.LongTensor(len(tensor), max_tgt_length).fill_(2)
        for i, (t, tgt_word_id) in enumerate(zip(tensor, xml_word_ids)):
            start = xml_src_lengths[i] + 3
            end = start + xml_tgt_lengths[i]
            selected_tensor = t[start:end]
            xml_tensor[i, :selected_tensor.shape[0]] = selected_tensor
            selected_tgt_word_ids = tgt_word_id[start:end]
            xml_tgt_word_ids[i, :selected_tensor.shape[0]] = selected_tgt_word_ids
        mask = torch.ne(xml_tgt_word_ids, 2).cuda().float()
        target_embeding = F.embedding(xml_tgt_word_ids.cuda(),
                                      xml_model.pred_layer.proj.weight)
        xml_tensor = xml_tensor.cuda() * (
            mask.unsqueeze(-1).expand_as(xml_tensor)) * target_embeding
        pre_qefv = torch.mul(xml_tensor, target_embeding)
        post_qefv = xml_tensor
        pre_qefv = torch.tanh(estimator.reduce_dim(pre_qefv))
        post_qefv = torch.tanh(estimator.reduce_dim(post_qefv))
        paded_tensor = torch.FloatTensor(xml_tensor.shape[0], xml_tensor.shape[1],
                                         4).fill_(0).cuda()
        prob = xml_model.pred_layer.proj(xml_tensor.cuda())
        prob = utils.log_softmax(prob, dim=-1)
        target = xml_tgt_word_ids.unsqueeze(-1).to(mt_qefv_prim.device)
        i_gt = prob.gather(dim=-1, index=target)
        i_max, i_argmax = prob.max(dim=-1, keepdim=True)
        i_equal = torch.eq(i_argmax, target).type_as(i_gt)
        i_gap = i_max - i_gt
        gap_fea = torch.cat([i_gt, i_max, i_equal, i_gap], dim=-1)
        xml_qefv = torch.cat([gap_fea, post_qefv, pre_qefv, post_qefv, pre_qefv], dim=-1)
    else:
        ''' extract both source and target features for XLM '''
        xml_tensor = tensor.transpose(0, 1)
        xml_word_ids = sample['xml_word_ids'].transpose(0, 1)
        mask = torch.ne(xml_word_ids, 2).cuda().float()
        target_embeding = F.embedding(xml_word_ids.cuda(),
                                      xml_model.pred_layer.proj.weight)
        xml_tensor = xml_tensor.cuda() * (
            mask.unsqueeze(-1).expand_as(xml_tensor)) * target_embeding
        pre_qefv = torch.mul(xml_tensor, target_embeding)
        post_qefv = xml_tensor
        pre_qefv = torch.tanh(estimator.reduce_dim(pre_qefv))
        post_qefv = torch.tanh(estimator.reduce_dim(post_qefv))
        paded_tensor = torch.FloatTensor(xml_tensor.shape[0], xml_tensor.shape[1],
                                         4).fill_(0).cuda()
        prob = xml_model.pred_layer.proj(tensor.transpose(0, 1))
        prob = utils.log_softmax(prob, dim=-1)
        target = xml_word_ids.unsqueeze(-1).to(mt_qefv_prim.device)
        i_gt = prob.gather(dim=-1, index=target)
        i_max, i_argmax = prob.max(dim=-1, keepdim=True)
        i_equal = torch.eq(i_argmax, target).type_as(i_gt)
        i_gap = i_max - i_gt
        gap_fea = torch.cat([i_gt, i_max, i_equal, i_gap], dim=-1)
        xml_qefv = torch.cat([gap_fea, post_qefv, pre_qefv, post_qefv, pre_qefv],
                             dim=-1)  # .transpose(0, 1)

    ter_prediction = 0
    if self.share_xml_dict:
        ter_prediction = estimator(torch.cat([xml_qefv, mt_qefv_prim], dim=-1))
    else:
        if self.estimator_xml_only:
            xml_qefv = torch.cat([pre_qefv, pre_qefv, pre_qefv, pre_qefv, pre_qefv,
                                  paded_tensor], dim=-1)
            ter_prediction += estimator(xml_qefv)
        else:
            xml_qefv = torch.cat([xml_qefv, xml_qefv, xml_qefv, xml_qefv, xml_qefv],
                                 dim=-1)
            ter_prediction += estimator.combine_forward(mt_qefv_prim, xml_qefv)
    return ter_prediction
def forward(self, x):
    return F.embedding(x, self.W_())
def forward(self, batch_nodes, node_features, fw_adj, bw_adj, idx_sql_seqs, sql_seqs_lens):
    seqs_embedded = self.embed_layer(idx_sql_seqs)
    seqs_embedded = self.seq_embedding_dropout(seqs_embedded)
    seqs_packed = pack_padded_sequence(seqs_embedded, sql_seqs_lens,
                                       batch_first=True, enforce_sorted=False)
    seqs_encoding, _ = self.seq_encoder(seqs_packed)
    seqs_encoding, _ = pad_packed_sequence(seqs_encoding, batch_first=True)
    seqs_encoding_mask = (idx_sql_seqs == 0).bool()  # [b, t]

    batch_size, seq_len = node_features.size()
    output = self.embed_layer(node_features)
    node_output, _ = self.node_feature_encoder(output)  # features are short, no need to pack
    node_embedding = node_output[:, -1, :]  # take the last timestep as initial node embedding ?

    fw_hidden = node_embedding
    bw_hidden = node_embedding.clone()
    embedded_node_rep = torch.cat(
        [node_embedding.clone(), torch.zeros([1, self.hidden_size]).to(conf.device)],
        dim=0)  # add a row of zeros for the PAD node

    fw_sampled_neighbors = fw_adj[:-1, :self.sample_size_per_layer]  # ignore PAD node
    bw_sampled_neighbors = bw_adj[:-1, :self.sample_size_per_layer]
    # fw_sampled_neighbors_len = fw_adj.size(0)
    # bw_sampled_neighbors_len = bw_adj.size(0)

    for l in range(self.sample_layer_size):
        dim_mul = 1 if l == 0 else 2  # because output is concatenated
        fw_aggregator = self.fw_aggregators[min(l, self.max_unique_sample_layers - 1)]
        if l == 0:
            # the PAD node will get zero embeddings
            neighbor_hiddens = F.embedding(fw_sampled_neighbors, embedded_node_rep)
        else:
            neighbor_hiddens = F.embedding(
                fw_sampled_neighbors,
                torch.cat([fw_hidden,
                           torch.zeros([1, dim_mul * self.hidden_size]).to(conf.device)],
                          dim=0))
        fw_hidden = fw_aggregator(fw_hidden, neighbor_hiddens)

        if self.graph_encode_direction == "bi":
            bw_aggregator = self.bw_aggregators[min(l, self.max_unique_sample_layers - 1)]
            if l == 0:
                neighbor_hiddens = F.embedding(bw_sampled_neighbors, embedded_node_rep)
            else:
                # the PAD node will get zero embeddings
                neighbor_hiddens = F.embedding(
                    bw_sampled_neighbors,
                    torch.cat([bw_hidden,
                               torch.zeros([1, dim_mul * self.hidden_size]).to(conf.device)],
                              dim=0))
            bw_hidden = bw_aggregator(bw_hidden, neighbor_hiddens)

    # Graph Embedding: max pooling
    if self.graph_encode_direction == "bi":
        hidden = torch.cat([fw_hidden, bw_hidden], axis=-1)
    else:
        hidden = fw_hidden
    hidden = F.relu(hidden)  # [b, out_h]
    out_hidden_size = hidden.size(1)

    num_graphs = len(batch_nodes)
    max_len = max([len(g) for g in batch_nodes])
    graph_hidden = torch.zeros([num_graphs, max_len, out_hidden_size]).to(conf.device)
    for i, g_node_idxs in enumerate(batch_nodes):
        graph_hidden[i, :len(g_node_idxs), :] = hidden[g_node_idxs[0]:g_node_idxs[-1] + 1]
    graph_embedding, _ = torch.max(graph_hidden, dim=1)  # [num_g, out_h]

    return graph_hidden, graph_embedding, max_len, seqs_encoding, seqs_encoding_mask
def forward(ctx, one_hot, emb, dim=-1):
    assert dim == -1
    _, idx = torch.max(one_hot, dim=-1)
    out = F.embedding(idx, emb)
    ctx.save_for_backward(one_hot, idx, emb, out)
    return out
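# --- Usage sketch (not from the original source). For a hard one-hot input, looking up the
# --- argmax row with F.embedding gives the same result as multiplying the one-hot matrix by
# --- the embedding table, which is what this forward pass relies on.
import torch
import torch.nn.functional as F

one_hot = F.one_hot(torch.tensor([2, 0, 1]), num_classes=4).float()  # (3, 4)
emb = torch.randn(4, 8)                                              # (4, 8)
idx = one_hot.argmax(dim=-1)
assert torch.allclose(F.embedding(idx, emb), one_hot @ emb)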
def forward(self, x):
    x = F.embedding(x, self.weights, padding_idx=self.padding_idx)
    return x
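# --- Minimal sketch with assumed weights and indices: F.embedding is a row gather, so
# --- F.embedding(x, W) equals W[x]; padding_idx only zeroes the gradient of that row,
# --- it does not change the forward output.
import torch
import torch.nn.functional as F

W = torch.randn(10, 4, requires_grad=True)
x = torch.tensor([[1, 0, 3]])
out = F.embedding(x, W, padding_idx=0)
assert torch.equal(out, W[x])
out.sum().backward()
assert torch.all(W.grad[0] == 0)  # the padding row receives no gradient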
def lookup(weight, inputs):
    idxes_t = input_idx(inputs)
    return F.embedding(idxes_t, weight)
def i(self, x, pos=True):
    x = F.embedding(x, self.out.weight * self.scale)
    if pos:
        x = x + positional_encodings_like(x)
    return x
user_ids = u[step * batch_size:(step + 1) * batch_size]
user_feats = {
    key: value[user_ids]
    for key, value in syn.user_feats.items()
}
user_feats = syn.to_device(user_feats)
logits = nominator(user_feats, val_item_feats)

if args.loss_type == "loss_ce":
    loss = loss_ce(logits, item_ids, item_probs)
elif args.loss_type == "loss_ips":
    loss = loss_ips(logits, item_ids, item_probs, upper_limit=10)
elif args.loss_type == "loss_2s":
    batch_ranker_logits = F.embedding(
        torch.LongTensor(user_ids).to(device), ranker_logits)
    loss = loss_2s(logits, item_ids, item_probs, batch_ranker_logits,
                   upper_limit=10, alpha=args.alpha)
else:
    raise NotImplementedError("{} not supported.".format(args.loss_type))

opt.zero_grad()
loss.backward()
opt.step()

with torch.no_grad():
def embedding(x, embed_param):
    amp_embed = F.embedding(x, embed_param[0], padding_idx=0)
    pha_embed = F.embedding(x, embed_param[1], padding_idx=0)
    mix_embed = F.embedding(x, embed_param[2], padding_idx=0)
    return (amp_embed, pha_embed), mix_embed
def get_tgt_inp(tgt, time_step):
    word_embed = F.embedding(tgt.type(src.type()), word_embedding) * self.scale
    pos_embed = pos_embedding[time_step, :].reshape(1, 1, -1)
    return word_embed + tgt_lang_embed + pos_embed
def get_entity(self, emb_id):
    t0 = time.time()
    entity = F.embedding(emb_id, self.model.emb.data, sparse=True)
    t1 = time.time()
    # print('get: ' + str(t1 - t0))
    return entity
def forward(self, indices):
    mu = F.embedding(indices, self.weight)
    std = F.embedding(indices, self.spread)
    return self.reparameterize(mu, std)
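# --- A minimal sketch of the reparameterization step assumed above (not the source's
# --- implementation): sample z = mu + std * eps with eps ~ N(0, I), which keeps the
# --- stochastic embedding lookup differentiable with respect to mu and std.
import torch

def reparameterize(mu: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    eps = torch.randn_like(std)
    return mu + std * eps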
def forward_frozen(self, x):
    return F.embedding(x, self.weight_mu, self.padding_idx, self.max_norm,
                       self.norm_type, self.scale_grad_by_freq, self.sparse)
def init_hidden(self, batch_size: int,
                init_batch: List[List[Relation]]) -> HiddenState:
    """
    Initialize stuff shared within whole batch, and return initial hidden states.

    topics: (batch_size), topic ID of each example.
    batch_rels: facts for each topic. Each a list of tuples (rel_id, obj_id, alias).
        - rel_id: {-1: NaF, -2: anchor, -3: topic_itself}
        - obj_id: -1: NaF / UNK
        - obj_alias: list of aliases for the relation.
    """
    batch_rels = init_batch

    """ Fact Embeddings """
    max_facts = max([len(r) for r in batch_rels])
    num_facts = torch.tensor([len(r) for r in batch_rels])

    # insert NaF as the first vector, and add 1 to all rel_ids
    # this is because we need to compute the CE loss including NaF, and -1 can't be used as target
    # fact_embeds: (batch_size, max_facts + 1, fact_embed_dim)
    fact_embeds = torch.stack([
        torch.cat([
            torch.cat(
                [
                    torch.stack([self.naf_vec[:self._kb_embed_dim]] + [
                        self.relation_vec[rel.rel_typ] if rel.rel_typ >= 0
                        else self.special_rel_vecs[rel.rel_typ + 1]
                        for rel in relations
                    ]),
                    torch.stack([self.naf_vec[self._kb_embed_dim:]] + [
                        self.entity_vec[rel.obj_id] if rel.obj_id != -1
                        else self.unk_entity_vec
                        for rel in relations
                    ])
                    # batches with fewer relations are padded to enable vectorized operations
                ],
                dim=1),
            torch.zeros(max_facts + 1 - num_facts[idx], self._fact_embed_dim,
                        device=self.device)
        ]) for idx, relations in enumerate(batch_rels)
    ])

    knowledge_embed = None
    if self._use_knowledge_embed:
        # knowledge_embed: (batch_size, fact_embed_dim)
        knowledge_embed = torch.stack([torch.mean(e, dim=0) for e in fact_embeds])

    alias_word_cnt = None
    if self._mask_invalid_pos:
        # alias_word_cnt: list(batch_size) of list(num_facts) of n_alias
        alias_word_cnt = [[[
            len(self.alias_list[alias].split()) for alias in rel.obj_alias
        ] for rel in relations] for relations in batch_rels]

    alias_vecs = alias_masks = None
    if self._alias_disamb is AliasDisamb.FastText:
        max_aliases = [
            max(len(rel.obj_alias) for rel in relations)
            for relations in batch_rels
        ]
        # alias_vecs: list(batch_size) of (max_facts + 1, max_aliases, alias_vec_dim)
        # alias_masks: list(batch_size) of (max_facts + 1, max_aliases)
        # the masks set invalid positions to -inf, used during log_softmax
        alias_vecs = []
        alias_masks = []
        for b, relations in enumerate(batch_rels):
            aliases = torch.zeros(len(relations), max_aliases[b] + 1,
                                  device=self.device, dtype=torch.long)
            mask = torch.full((len(relations), max_aliases[b] + 1), -math.inf,
                              device=self.device)
            for idx, rel in enumerate(relations):
                aliases[idx, :len(rel.obj_alias)] = torch.tensor(
                    rel.obj_alias, device=self.device)
                mask[idx, :len(rel.obj_alias)] = 0
            vectors = F.embedding(aliases, self.alias_vec)
            alias_vecs.append(vectors)
            alias_masks.append(mask)

    self._cache = {
        'num_facts': num_facts,
        'fact_embeds': fact_embeds,
        'knowledge_embed': knowledge_embed,
        'alias_word_cnt': alias_word_cnt,
        'alias_vecs': alias_vecs,
        'alias_masks': alias_masks,
    }
    return self.rnn.init_hidden(batch_size)
def decode_model(args, model, dev, evaluate=True, decoding_path=None, names=None,
                 maxsteps=None):

    args.logger.info("decoding, f_size={}, beam_size={}, alpha={}".format(
        args.f_size, args.beam_size, args.alpha))
    dev.train = False  # make iterator volatile=True

    if maxsteps is None:
        progressbar = tqdm(total=sum([1 for _ in dev]), desc='start decoding')
    else:
        progressbar = tqdm(total=maxsteps, desc='start decoding')

    model.eval()
    if decoding_path is not None:
        handles = [open(os.path.join(decoding_path, name), 'w') for name in names]

    corpus_size = 0
    src_outputs, trg_outputs, dec_outputs, timings = [], [], [], []
    decoded_words, target_words, decoded_info = 0, 0, 0

    attentions = None
    pad_id = model.decoder[0].field.vocab.stoi['<pad>']
    eos_id = model.decoder[0].field.vocab.stoi['<eos>']

    curr_time = 0
    cum_bs = 0
    for iters, dev_batch in enumerate(dev):
        if iters > maxsteps:
            args.logger.info('complete {} steps of decoding'.format(maxsteps))
            break

        start_t = time.time()

        # encoding
        inputs, input_masks, \
        targets, target_masks, \
        sources, source_masks, \
        encoding, batch_size = model.quick_prepare(dev_batch)

        cum_bs += batch_size
        # for now
        if type(model) is Transformer:
            all_decodings = []
            decoder_inputs, decoder_masks = inputs, input_masks
            decoding = model(encoding, source_masks, decoder_inputs, decoder_masks,
                             beam=args.beam_size, alpha=args.alpha,
                             decoding=True, feedback=attentions)
            all_decodings.append(decoding)

        elif type(model) is FastTransformer:
            decoder_inputs, _, decoder_masks = \
                model.prepare_initial(encoding, sources, source_masks, input_masks,
                                      N=args.f_size)

            batch_size, src_len, hsize = encoding[0].size()

            all_decodings = []
            prev_dec_output = None
            iter_ = 0
            while True:
                iter_num = min(iter_, args.num_shared_dec - 1)
                next_iter = min(iter_ + 1, args.num_shared_dec - 1)

                decoding, out, probs = model(encoding, source_masks, decoder_inputs,
                                             decoder_masks, decoding=True,
                                             return_probs=True, iter_=iter_num)
                all_decodings.append(decoding)

                thedecoder = model.decoder[iter_num]

                logits = thedecoder.out(out)
                _, argmax = torch.max(logits, dim=-1)

                decoder_inputs = F.embedding(
                    argmax, model.decoder[next_iter].out.weight * math.sqrt(args.d_model))
                if args.sum_out_and_emb:
                    decoder_inputs += out

                iter_ += 1
                if iter_ == args.valid_repeat_dec:
                    break

        used_t = time.time() - start_t
        curr_time += used_t

        real_mask = 1 - ((decoding.data == eos_id) + (decoding.data == pad_id)).float()
        outputs = [
            model.output_decoding(d)
            for d in [('src', sources), ('trg', targets), ('trg', decoding)]
        ]
        all_dec_outputs = [
            model.output_decoding(d)
            for d in [('trg', all_decodings[ii]) for ii in range(len(all_decodings))]
        ]

        corpus_size += batch_size
        src_outputs += outputs[0]
        trg_outputs += outputs[1]
        dec_outputs += outputs[-1]

        """
        for sent_i in range(len(outputs[0])):
            print ('SRC')
            print (outputs[0][sent_i])
            print ('TRG')
            print (outputs[1][sent_i])
            for ii in range(len(all_decodings)):
                print ('DEC iter {}'.format(ii))
                print (all_dec_outputs[ii][sent_i])
            print ('---------------------------')
        """

        timings += [used_t]

        if decoding_path is not None:
            for s, t, d in zip(outputs[0], outputs[1], outputs[2]):
                s, t, d = s.replace('@@ ', ''), t.replace('@@ ', ''), d.replace('@@ ', '')
                print(s, file=handles[0], flush=True)
                print(t, file=handles[1], flush=True)
                print(d, file=handles[2], flush=True)

        print(curr_time / float(cum_bs) * 1000)

        # progressbar.update(1)
        # progressbar.set_description('finishing sentences={}/batches={}, speed={} sec/batch'.format(
        #     corpus_size, iters, curr_time / (1 + iters)))

    if evaluate:
        corpus_bleu = computeBLEU(dec_outputs, trg_outputs, corpus=True, tokenizer=tokenizer)
        # args.logger.info("The dev-set corpus BLEU = {}".format(corpus_bleu))
        print("The dev-set corpus BLEU = {}".format(corpus_bleu))
def log_prob(self, input, precompute=None):
    """
    Returns a [batch_size x z_dim] log-probability of input given state z
    """
    emit_prob, = precompute
    return F.embedding(input, emit_prob)
def embed_code(self, embed_id):
    return F.embedding(embed_id, self.embed.transpose(0, 1))
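# --- Sketch under the assumption that the codebook is stored transposed, i.e. self.embed has
# --- shape (dim, n_embed) as in common VQ-VAE implementations; F.embedding expects
# --- (num_embeddings, dim), hence the transpose(0, 1).
import torch
import torch.nn.functional as F

dim, n_embed = 8, 16
codebook = torch.randn(dim, n_embed)                 # stored as (dim, n_embed)
ids = torch.randint(0, n_embed, (4, 5))
codes = F.embedding(ids, codebook.transpose(0, 1))   # (4, 5, dim)
assert codes.shape == (4, 5, dim)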
def forward(self, query, key, value, key_padding_mask=None, incremental_state=None,
            need_weights=True, static_kv=False, attn_mask=None):
    """Input shape: Time x Batch x Channel

    Self-attention can be implemented by passing in the same arguments for
    query, key and value. Timesteps can be masked by supplying a T x T mask in the
    `attn_mask` argument. Padding elements can be excluded from
    the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
    batch x src_len, where padding elements are indicated by 1s.
    """

    qkv_same = query.data_ptr() == key.data_ptr() == value.data_ptr()
    kv_same = key.data_ptr() == value.data_ptr()

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]
    assert key.size() == value.size()

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if 'prev_key' in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert kv_same and not qkv_same
                key = value = None
    else:
        saved_state = None

    if qkv_same:
        # self-attention
        q, k, v = self.in_proj_qkv(query)
    elif kv_same:
        # encoder-decoder attention
        q = self.in_proj_q(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k, v = self.in_proj_kv(key)
    else:
        q = self.in_proj_q(query)
        k = self.in_proj_k(key)
        v = self.in_proj_v(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat([
                key_padding_mask,
                key_padding_mask.new_zeros(key_padding_mask.size(0), 1)
            ], dim=1)

    q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    if k is not None:
        k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)
    if v is not None:
        v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1)

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if 'prev_key' in saved_state:
            prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                k = torch.cat((prev_key, k), dim=1)
        if 'prev_value' in saved_state:
            prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1,
                                                        self.head_dim)
            if static_kv:
                v = prev_value
            else:
                v = torch.cat((prev_value, v), dim=1)
        saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim)

        self._set_input_buffer(incremental_state, saved_state)

    src_len = k.size(1)

    if key_padding_mask is not None:
        assert key_padding_mask.size(0) == bsz
        assert key_padding_mask.size(1) == src_len

    if self.add_zero_attn:
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1)
        if key_padding_mask is not None:
            key_padding_mask = torch.cat([
                key_padding_mask,
                torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)
            ], dim=1)

    relative_positions_matrix = self._generate_relative_positions_matrix(
        src_len, self.max_relative_length, incremental_state)

    if self.k_only:
        relation_keys = F.embedding(relative_positions_matrix.long().cuda(),
                                    self.relative_position_keys)
    else:
        relation_keys = F.embedding(relative_positions_matrix.long().cuda(),
                                    self.relative_position_keys)
        relation_values = F.embedding(relative_positions_matrix.long().cuda(),
                                      self.relative_position_values)

    relative_attn_weights = self._relative_attention_inner(q, k, relation_keys,
                                                           transpose=True)
    assert list(relative_attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            attn_mask = attn_mask.repeat(relative_attn_weights.size(0), 1, 1)
        relative_attn_weights += attn_mask

    if key_padding_mask is not None:
        # don't attend to padding symbols
        relative_attn_weights = relative_attn_weights.view(bsz, self.num_heads,
                                                           tgt_len, src_len)
        if self.onnx_trace:
            relative_attn_weights = torch.where(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                torch.Tensor([float("-Inf")]),
                relative_attn_weights.float()).type_as(relative_attn_weights)
        else:
            relative_attn_weights = relative_attn_weights.float().masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            ).type_as(relative_attn_weights)  # FP16 support: cast to float and back
        relative_attn_weights = relative_attn_weights.view(bsz * self.num_heads,
                                                           tgt_len, src_len)

    relative_attn_weights = utils.softmax(
        relative_attn_weights, dim=-1, onnx_trace=self.onnx_trace,
    ).type_as(relative_attn_weights)
    relative_attn_weights = F.dropout(relative_attn_weights, p=self.dropout,
                                      training=self.training)

    # key only mode
    if self.k_only:
        attn = torch.bmm(relative_attn_weights, v)
    # original implementation
    else:
        attn = self._relative_attention_inner(relative_attn_weights, v, relation_values,
                                              transpose=False)
    # attn = torch.bmm(relative_attn_weights, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]

    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)

    if need_weights:
        # average attention weights over heads
        relative_attn_weights = relative_attn_weights.view(bsz, self.num_heads,
                                                           tgt_len, src_len)
        relative_attn_weights = relative_attn_weights.sum(dim=1) / self.num_heads
    else:
        relative_attn_weights = None

    return attn, relative_attn_weights
def forward(self, batch_size, index, feats, values):
    # type: (int, Tensor, Tensor, Tensor) -> Tensor
    batch_first = F.embedding(feats, self.weights)
    batch_second = F.embedding(feats, self.embedding)
    return self.forward_(batch_size, index, feats, values, self.bias, batch_first,
                         batch_second, self.mats)
def get(self, input_: LongTensorType) -> FloatTensorType:
    return F.embedding(
        input_,
        self.weight,
        max_norm=self.max_norm,
        sparse=True,
    )
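# --- Note on sparse=True (a sketch, not from the original source): the lookup output is
# --- dense, but the gradient with respect to the weight is a sparse tensor, which only
# --- optimizers that accept sparse gradients (e.g. SGD, Adagrad, SparseAdam) can consume.
import torch
import torch.nn.functional as F

weight = torch.randn(100, 16, requires_grad=True)
out = F.embedding(torch.tensor([3, 7, 7]), weight, sparse=True)
out.sum().backward()
assert weight.grad.is_sparse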
def decode(self, enc_hiddens: torch.Tensor, enc_masks: torch.Tensor,
           dec_init_state: Tuple[torch.Tensor, torch.Tensor],
           target_padded: torch.Tensor) -> torch.Tensor:
    """Compute combined output vectors for a batch.

    @param enc_hiddens (Tensor): Hidden states (b, src_len, h*2), where
                                 b = batch size, src_len = maximum source sentence length, h = hidden size.
    @param enc_masks (Tensor): Tensor of sentence masks (b, src_len), where
                                 b = batch size, src_len = maximum source sentence length.
    @param dec_init_state (tuple(Tensor, Tensor)): Initial state and cell for decoder
    @param target_padded (Tensor): Gold-standard padded target sentences (tgt_len, b), where
                                 tgt_len = maximum target sentence length, b = batch size.
    @returns combined_outputs (Tensor): combined output tensor (tgt_len, b, h), where
                                 tgt_len = maximum target sentence length, b = batch_size, h = hidden size
    """
    # Chop off the <END> token for max length sentences.
    target_padded = target_padded[:-1]

    # Initialize the decoder state (hidden and cell)
    dec_state = dec_init_state

    # Initialize previous combined output vector o_{t-1} as zero
    batch_size = enc_hiddens.size(0)
    o_prev = torch.zeros(batch_size, self.hidden_size, device=self.device)

    # Initialize a list we will use to collect the combined output o_t on each step
    combined_outputs = []

    ### YOUR CODE HERE (~9 Lines)
    ### TODO:
    ###     1. Apply the attention projection layer to `enc_hiddens` to obtain `enc_hiddens_proj`,
    ###        which should be shape (b, src_len, h),
    ###        where b = batch size, src_len = maximum source length, h = hidden size.
    ###        This is applying W_{attProj} to h^enc, as described in the PDF.
    ###     2. Construct tensor `Y` of target sentences with shape (tgt_len, b, e) using the target model embeddings.
    ###        where tgt_len = maximum target sentence length, b = batch size, e = embedding size.
    ###     3. Use the torch.split function to iterate over the time dimension of Y.
    ###        Within the loop, this will give you Y_t of shape (1, b, e) where b = batch size, e = embedding size.
    ###            - Squeeze Y_t into a tensor of dimension (b, e).
    ###            - Construct Ybar_t by concatenating Y_t with o_prev.
    ###            - Use the step function to compute the Decoder's next (cell, state) values
    ###              as well as the new combined output o_t.
    ###            - Append o_t to combined_outputs
    ###            - Update o_prev to the new o_t.
    ###     4. Use torch.stack to convert combined_outputs from a list length tgt_len of
    ###        tensors shape (b, h), to a single tensor shape (tgt_len, b, h)
    ###        where tgt_len = maximum target sentence length, b = batch size, h = hidden size.
    ###
    ### Note:
    ###    - When using the squeeze() function make sure to specify the dimension you want to squeeze
    ###      over. Otherwise, you will remove the batch dimension accidentally, if batch_size = 1.
    ###
    ### Use the following docs to implement this functionality:
    ###     Zeros Tensor:
    ###         https://pytorch.org/docs/stable/torch.html#torch.zeros
    ###     Tensor Splitting (iteration):
    ###         https://pytorch.org/docs/stable/torch.html#torch.split
    ###     Tensor Dimension Squeezing:
    ###         https://pytorch.org/docs/stable/torch.html#torch.squeeze
    ###     Tensor Concatenation:
    ###         https://pytorch.org/docs/stable/torch.html#torch.cat
    ###     Tensor Stacking:
    ###         https://pytorch.org/docs/stable/torch.html#torch.stack

    enc_hiddens_proj = self.att_projection(enc_hiddens)  # (b, src_len, h)
    tgt_padded_embedding = F.embedding(target_padded,
                                       self.model_embeddings.target.weight)
    Y = tgt_padded_embedding  # (tgt_len, b, e)
    for Y_t in torch.split(Y, 1, 0):
        Y_t = torch.squeeze(Y_t, dim=0)       # (1, b, e) -> (b, e)
        Ybar_t = torch.cat((Y_t, o_prev), 1)  # (b, e) + (b, h) -> (b, e + h)
        dec_state, o_t, e_t = self.step(Ybar_t, dec_state, enc_hiddens,
                                        enc_hiddens_proj, enc_masks)
        combined_outputs.append(o_t)
        o_prev = o_t
    combined_outputs = torch.stack(combined_outputs, dim=0)
    ### END YOUR CODE

    return combined_outputs  # (tgt_len, b, h)
    set_one_emb = sp.load_npz(emb_path + '/set_one.npz').toarray()
    set_two_emb = sp.load_npz(emb_path + '/set_two.npz').toarray()
elif method == 'hegan':
    emb_path = '/Users/tian/Documents/P4_Bipartite_Graph_Representation/journel_version/baselines/HeGAN/results/' \
               + dataset + '/recommendation'
    set_one_emb = sp.load_npz(emb_path + '/set_one.npz').toarray()
    set_two_emb = sp.load_npz(emb_path + '/set_two.npz').toarray()
else:
    assert False, "Wrong model!"

print("method: {}, dataset: {}".format(method, dataset))

train_edges = torch.tensor(dataloader.edges)
set_one_emb = torch.FloatTensor(set_one_emb)
set_two_emb = torch.FloatTensor(set_two_emb)
set_one_pos = F.embedding(train_edges[:, 0], set_one_emb)
set_two_pos = F.embedding(train_edges[:, 1], set_two_emb)

# load the testing edges
test_edges = torch.tensor(dataloader.label.toarray())
edges = torch.cat((train_edges, test_edges), dim=0)

neg_ratio = 1
neg_row, neg_col = negative_sampling_edges(set_one_emb.shape[0], set_two_emb.shape[0],
                                           edges, neg_ratio)
set_one_neg = F.embedding(neg_row, set_one_emb)
set_two_neg = F.embedding(neg_col, set_two_emb)

X_one = torch.cat((set_one_pos, set_one_neg), dim=0)
X_two = torch.cat((set_two_pos, set_two_neg), dim=0)
lb_pos = torch.ones(set_one_pos.shape[0])
def depth_r_score(self, rel: Tensor, arg1: Tensor, arg2: Tensor, facts: List[Tensor],
                  entity_embeddings: Tensor, depth: int) -> Tensor:
    assert depth >= 0

    if depth == 0:
        return self.model.score(rel, arg1, arg2, facts)

    batch_size, embedding_size = rel.shape[0], rel.shape[1]
    global_res = None

    mask = None
    new_hops_lst = self.hops_lst

    if self.R is not None:
        batch_rules_scores = torch.cat(
            [h.prior(rel).view(-1, 1) for h, _ in self.hops_lst], 1)
        topk, indices = torch.topk(batch_rules_scores, self.R)

        # [R x E]
        rule_heads = torch.cat([h.head for h, _ in self.hops_lst], dim=0)
        rule_body1s = torch.cat([h.memory_lst[0] for h, _ in self.hops_lst], dim=0)
        rule_body2s = torch.cat([h.memory_lst[1] for h, _ in self.hops_lst], dim=0)

        kernel = self.hops_lst[0][0].kernel
        new_rule_heads = F.embedding(indices, rule_heads)
        new_rule_body1s = F.embedding(indices, rule_body1s)
        new_rule_body2s = F.embedding(indices, rule_body2s)

        # print(new_rule_heads.shape[1], self.R)
        assert new_rule_heads.shape[1] == self.R

        new_hops_lst = []
        for i in range(new_rule_heads.shape[1]):
            r = GNTPReformulator(kernel=kernel,
                                 head=new_rule_heads[:, i, :],
                                 body=[new_rule_body1s[:, i, :],
                                       new_rule_body2s[:, i, :]])
            new_hops_lst += [(r, False)]

        # mask = torch.zeros_like(batch_rules_scores).scatter_(1, indices, torch.ones_like(topk))

    # for hops_generator, is_reversed in self.hops_lst:
    # for rule_idx, (hops_generator, is_reversed) in enumerate(self.hops_lst):
    for rule_idx, (hops_generator, is_reversed) in enumerate(new_hops_lst):
        sources, scores = arg1, None

        # XXX
        prior = hops_generator.prior(rel)
        if prior is not None:
            if mask is not None:
                prior = prior * mask[:, rule_idx]
                if (prior != 0.0).sum() == 0:
                    continue
            scores = prior
        # scores = hops_generator.prior(rel)

        hop_rel_lst = hops_generator(rel)
        nb_hops = len(hop_rel_lst)

        for hop_idx, hop_rel in enumerate(hop_rel_lst, start=1):
            # [B * S, E]
            sources_2d = sources.view(-1, embedding_size)
            nb_sources = sources_2d.shape[0]

            nb_branches = nb_sources // batch_size

            hop_rel_3d = hop_rel.view(-1, 1, embedding_size).repeat(1, nb_branches, 1)
            hop_rel_2d = hop_rel_3d.view(-1, embedding_size)

            if hop_idx < nb_hops:
                # [B * S, K], [B * S, K, E]
                if is_reversed:
                    z_scores, z_emb = self.r_hop(hop_rel_2d, None, sources_2d, facts,
                                                 entity_embeddings, depth=depth - 1)
                else:
                    z_scores, z_emb = self.r_hop(hop_rel_2d, sources_2d, None, facts,
                                                 entity_embeddings, depth=depth - 1)
                k = z_emb.shape[1]

                # [B * S * K]
                z_scores_1d = z_scores.view(-1)
                # [B * S * K, E]
                z_emb_2d = z_emb.view(-1, embedding_size)

                # [B * S * K, E]
                sources = z_emb_2d
                # [B * S * K]
                scores = z_scores_1d if scores is None \
                    else self._tnorm(z_scores_1d, scores.view(-1, 1).repeat(1, k).view(-1))
            else:
                # [B, S, E]
                arg2_3d = arg2.view(-1, 1, embedding_size).repeat(1, nb_branches, 1)
                # [B * S, E]
                arg2_2d = arg2_3d.view(-1, embedding_size)

                # [B * S]
                if is_reversed:
                    z_scores_1d = self.r_score(hop_rel_2d, arg2_2d, sources_2d, facts,
                                               entity_embeddings, depth=depth - 1)
                else:
                    z_scores_1d = self.r_score(hop_rel_2d, sources_2d, arg2_2d, facts,
                                               entity_embeddings, depth=depth - 1)

                scores = z_scores_1d if scores is None else self._tnorm(z_scores_1d, scores)

        if scores is not None:
            scores_2d = scores.view(batch_size, -1)
            res, _ = torch.max(scores_2d, dim=1)
        else:
            res = self.model.score(rel, arg1, arg2, facts)

        global_res = res if global_res is None else torch.max(global_res, res)

    return global_res
def _embedding_fn(x, y):
    x_emb = F.embedding(x, self._embedding)
    y_emb = F.embedding(y, self._pos_embedding)
    return x_emb * self._emb_dim ** 0.5 + y_emb
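# --- Minimal usage sketch with assumed shapes (not from the original source): token
# --- embeddings are scaled by sqrt(d_model) before the positional embedding is added,
# --- the usual Transformer input convention.
import torch
import torch.nn.functional as F

d_model, vocab, max_len = 16, 100, 50
tok_table = torch.randn(vocab, d_model)
pos_table = torch.randn(max_len, d_model)
tokens = torch.tensor([[5, 9, 2]])
positions = torch.arange(tokens.size(1)).unsqueeze(0)
x = F.embedding(tokens, tok_table) * d_model ** 0.5 + F.embedding(positions, pos_table)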
def forward(self, batch_size, index, feats, values):
    # type: (int, Tensor, Tensor, Tensor) -> Tensor
    weight = F.embedding(feats, self.weights)
    bias = self.bias
    return self.forward_(batch_size, index, feats, values, bias, weight)
def embed_target(self, inp):
    return self.act(F.embedding(inp, self.characterProjection.weight))