def forward(self, feats, edges=None):
    # allocate memory
    dtype, device = feats.dtype, feats.device
    edges = edges.view(-1, 3)
    V, E = feats.size(0), edges.size(0)
    pooled_v_pos = torch.zeros(V, feats.shape[-3], feats.shape[-1],
                               feats.shape[-1], dtype=dtype, device=device)
    pooled_v_neg = torch.zeros(V, feats.shape[-3], feats.shape[-1],
                               feats.shape[-1], dtype=dtype, device=device)

    # pool positive edges
    pos_inds = torch.where(edges[:, 1] > 0)
    pos_v_src = torch.cat([edges[pos_inds[0], 0], edges[pos_inds[0], 2]]).long()
    pos_v_dst = torch.cat([edges[pos_inds[0], 2], edges[pos_inds[0], 0]]).long()
    pos_vecs_src = feats[pos_v_src.contiguous()]
    pos_v_dst = pos_v_dst.view(-1, 1, 1, 1).expand_as(pos_vecs_src).to(device)
    pooled_v_pos = torch.scatter_add(pooled_v_pos, 0, pos_v_dst, pos_vecs_src)

    # pool negative edges
    neg_inds = torch.where(edges[:, 1] < 0)
    neg_v_src = torch.cat([edges[neg_inds[0], 0], edges[neg_inds[0], 2]]).long()
    neg_v_dst = torch.cat([edges[neg_inds[0], 2], edges[neg_inds[0], 0]]).long()
    neg_vecs_src = feats[neg_v_src.contiguous()]
    neg_v_dst = neg_v_dst.view(-1, 1, 1, 1).expand_as(neg_vecs_src).to(device)
    pooled_v_neg = torch.scatter_add(pooled_v_neg, 0, neg_v_dst, neg_vecs_src)

    # update node features
    enc_in = torch.cat([feats, pooled_v_pos, pooled_v_neg], 1)
    out = self.encoder(enc_in)
    return out
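# A minimal, self-contained sketch (not from the original module) of the pooling
# pattern above: each edge's source-node features are accumulated into the
# destination node's slot via torch.scatter_add. All shapes here are invented.
import torch

def _edge_pool_demo():
    V, C = 4, 2  # 4 nodes, 2 feature channels
    feats = torch.arange(V * C, dtype=torch.float32).view(V, C)
    src = torch.tensor([1, 0, 3])  # source node of each edge
    dst = torch.tensor([0, 2, 2])  # destination node of each edge
    pooled = torch.zeros(V, C)
    index = dst.view(-1, 1).expand(-1, C)  # broadcast destination ids over channels
    # pooled[dst[e], c] += feats[src[e], c] for every edge e
    return torch.scatter_add(pooled, 0, index, feats[src])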
def forward(  # pylint: disable=missing-function-docstring
    self,
    segment_sizes: "torch.Tensor",
    features: "torch.Tensor",
) -> "torch.Tensor":
    n_seg = len(segment_sizes)
    encoded_features = self.encoder(features)
    segment_indices = torch.repeat_interleave(
        torch.arange(n_seg, device=features.device),
        segment_sizes.long(),
    )
    n_dim = encoded_features.shape[1]
    segment_sum = torch.scatter_add(
        input=torch.zeros((n_seg, n_dim),
                          dtype=encoded_features.dtype,
                          device=features.device),
        dim=0,
        index=segment_indices.view(-1, 1).expand(-1, n_dim),
        src=encoded_features,
    )
    out = self.norm(segment_sum)
    out = self.layer0(out) + out
    out = self.layer1(out) + out
    out = self.decoder(out).squeeze()
    out = self.sigmoid(out)
    return out
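# Hedged sanity check (not part of the original model): the scatter_add segment
# sum above should agree with a plain Python loop over segments. Sizes are made up.
import torch

def _segment_sum_check():
    segment_sizes = torch.tensor([2, 1, 3])
    feats = torch.randn(int(segment_sizes.sum()), 4)
    idx = torch.repeat_interleave(torch.arange(len(segment_sizes)), segment_sizes)
    out = torch.scatter_add(
        torch.zeros(len(segment_sizes), 4), 0, idx.view(-1, 1).expand(-1, 4), feats)
    expected = torch.stack([feats[idx == i].sum(0) for i in range(len(segment_sizes))])
    assert torch.allclose(out, expected)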
def label_weight(this, shape, label):
    # count label occurrences, with 0.1 smoothing so empty classes stay finite
    weight = torch.zeros(shape).to(label.device) + 0.1
    weight = torch.scatter_add(weight, 0, label, torch.ones_like(label).float())
    # inverse-frequency weighting; the last class is further downweighted
    weight = 1. / weight
    weight[-1] /= 200
    return weight
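# Illustrative only: label_weight builds inverse-frequency class weights by
# counting labels with scatter_add. A toy run under assumed shapes (the original
# additionally divides the last class's weight by 200):
import torch

label = torch.tensor([0, 0, 1, 3, 3, 3])
counts = torch.zeros(4) + 0.1  # 0.1 smoothing, as above
counts = torch.scatter_add(counts, 0, label, torch.ones_like(label).float())
weight = 1.0 / counts  # counts = [2.1, 1.1, 0.1, 3.1]; rare classes weigh more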
def update_balance_profile(
    balance_dict,
    gate_top_k_idx,
    _gate_score_top_k,
    gate_context,
    layer_idx,
    num_expert,
    balance_strategy,
):
    c_e = torch.scatter_add(
        torch.zeros(num_expert, device=gate_top_k_idx.device),
        0,
        gate_top_k_idx,
        torch.ones_like(gate_top_k_idx, dtype=torch.float),
    )
    for key in metrics:
        balance_dict[key][layer_idx] = metrics[key](c_e)
    S = gate_top_k_idx.shape[0]
    if balance_strategy == "gshard":
        gate_score_all = gate_context
        m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
        balance_dict["gshard_loss"][layer_idx] = torch.sum(
            c_e * m_e) / num_expert / S
    elif balance_strategy == "noisy":
        balance_dict["noisy_loss"][layer_idx] = gate_context
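# Sketch of the expert-count computation above in isolation: c_e[e] is the
# number of tokens routed to expert e. The routing indices are hypothetical.
import torch

gate_top_k_idx = torch.tensor([0, 2, 2, 3, 0, 2])
c_e = torch.scatter_add(
    torch.zeros(4), 0, gate_top_k_idx,
    torch.ones_like(gate_top_k_idx, dtype=torch.float))
# c_e == tensor([2., 0., 3., 1.])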
def tensor_indexing_ops(self):
    x = torch.randn(2, 4)
    y = torch.randn(4, 4)
    t = torch.tensor([[0, 0], [1, 0]])
    mask = x.ge(0.5)
    i = [0, 1]
    # len() takes a single argument, so the ops are collected in one tuple
    return len((
        torch.cat((x, x, x), 0),
        torch.concat((x, x, x), 0),
        torch.conj(x),
        torch.chunk(x, 2),
        torch.dsplit(torch.randn(2, 2, 4), i),
        torch.column_stack((x, x)),
        torch.dstack((x, x)),
        torch.gather(x, 0, t),
        torch.hsplit(x, i),
        torch.hstack((x, x)),
        torch.index_select(x, 0, torch.tensor([0, 1])),
        x.index(t),
        torch.masked_select(x, mask),
        torch.movedim(x, 1, 0),
        torch.moveaxis(x, 1, 0),
        torch.narrow(x, 0, 0, 2),
        torch.nonzero(x),
        torch.permute(x, (0, 1)),
        torch.reshape(x, (-1, )),
        torch.row_stack((x, x)),
        torch.select(x, 0, 0),
        torch.scatter(x, 0, t, x),
        x.scatter(0, t, x.clone()),
        torch.diagonal_scatter(y, torch.ones(4)),
        torch.select_scatter(y, torch.ones(4), 0, 0),
        torch.slice_scatter(x, x),
        torch.scatter_add(x, 0, t, x),
        x.scatter_(0, t, y),
        x.scatter_add_(0, t, y),
        # torch.scatter_reduce(x, 0, t, reduce="sum"),
        torch.split(x, 1),
        torch.squeeze(x, 0),
        torch.stack([x, x]),
        torch.swapaxes(x, 0, 1),
        torch.swapdims(x, 0, 1),
        torch.t(x),
        torch.take(x, t),
        torch.take_along_dim(x, torch.argmax(x)),
        torch.tensor_split(x, 1),
        torch.tensor_split(x, [0, 1]),
        torch.tile(x, (2, 2)),
        torch.transpose(x, 0, 1),
        torch.unbind(x),
        torch.unsqueeze(x, -1),
        torch.vsplit(x, i),
        torch.vstack((x, x)),
        torch.where(x),
        torch.where(t > 0, t, 0),
        torch.where(t > 0, t, t),
    ))
def calc_b_new(layer, b_tilde):
    with T.no_grad():
        act_shift = T.max(T.abs(b_tilde))
        # For b_tilde < 0 replace biases with b_tilde
        b_new = T.scatter(layer.bias, 0,
                          T.where(b_tilde < 0)[0], b_tilde[b_tilde < 0])
        # For b_tilde < 0 update biases
        b_new = T.scatter_add(-T.abs(b_new), 0,
                              T.where(b_tilde < 0)[0],
                              act_shift.expand_as(T.where(b_tilde < 0)[0]))
        # For b_tilde > 0 update biases
        b_new = T.scatter_add(b_new, 0,
                              T.where(b_tilde > 0)[0],
                              act_shift.expand_as(T.where(b_tilde > 0)[0]))
    return b_new, act_shift
def compute_context(self, weights, relpos):
    """
    :param weights: (batsize, numheads, qlen, klen)
    :param relpos: (batsize, qlen, klen, 1)
    :return: weighted sum over klen (batsize, numheads, qlen, dimperhead)
    """
    ret = None
    batsize = weights.size(0)
    numheads = weights.size(1)
    qlen = weights.size(2)
    device = weights.device
    # A naive implementation would build matrices of
    # (batsize, numheads, qlen, klen, dimperhead), whereas a normal transformer
    # only needs (batsize, numheads, qlen, klen) and (batsize, numheads, klen, dimperhead).
    for n in range(relpos.size(-1)):
        relpos_ = relpos[:, :, :, n]
        # map relpos_ to a compact integer space of unique relpos_ entries
        relpos_unique = relpos_.unique()
        mapper = torch.zeros(relpos_unique.max() + 1, device=device,
                             dtype=torch.long)  # mapper inverts relpos_unique
        mapper[relpos_unique] = torch.arange(0, relpos_unique.size(0),
                                             device=device).long()
        relpos_mapped = mapper[relpos_]  # (batsize, qlen, klen), ids in [0, numunique)
        # sum up the attention weights which refer to the same relpos id
        # scatter: src is weights, index is relpos_mapped[:, None, :, :]
        # gathered[batch, head, qpos, relpos_mapped[batch, head, qpos, kpos]]
        #     += weights[batch, head, qpos, kpos]
        gathered = torch.zeros(batsize, numheads, qlen, relpos_unique.size(0),
                               device=device)
        # repeat only over the head dimension; repeating over the batch
        # dimension as well would inflate the index to (batsize**2, ...)
        gathered = torch.scatter_add(
            gathered, -1,
            relpos_mapped[:, None, :, :].repeat(1, numheads, 1, 1),
            weights)  # --> (batsize, numheads, qlen, numunique): summed attention weights
        # get embeddings and update ret
        embs = self.embv(relpos_unique).view(
            relpos_unique.size(0), numheads, -1)  # (numunique, numheads, dimperhead)
        relposemb = torch.einsum("bhqn,nhd->bhqd", gathered, embs)
        if ret is None:
            ret = torch.zeros_like(relposemb)
        ret = ret + relposemb
    return ret
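# Toy version (values invented) of the relpos compaction above: map raw values
# to a dense id space, then sum the weights that share the same id.
import torch

relpos = torch.tensor([3, 7, 3, 9])
uniq = relpos.unique()  # tensor([3, 7, 9])
mapper = torch.zeros(uniq.max() + 1, dtype=torch.long)
mapper[uniq] = torch.arange(uniq.numel())
weights = torch.tensor([0.1, 0.2, 0.3, 0.4])
gathered = torch.scatter_add(torch.zeros(uniq.numel()), 0, mapper[relpos], weights)
# gathered == tensor([0.4, 0.2, 0.4]): weights with the same relpos are summed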
def compute_degree_distr(data_x, k):
    data_y = data_x
    data_x_len = len(data_x)
    mean_dist_a = torch.zeros(len(data_x), device=device)
    mean_dist_b = torch.zeros(len(data_x), device=device)
    batch_sz = 700
    y_norm = (data_y**2).sum(-1).unsqueeze(0)
    data_y = data_y.t()
    degrees = torch.zeros(data_x_len, device=device)
    for i in range(0, data_x_len, batch_sz):
        j = min(data_x_len, i + batch_sz)
        x = data_x[i:j]
        x_norm = (x**2).sum(-1).unsqueeze(-1)
        cur_dist = x_norm + y_norm - 2 * torch.mm(x, data_y)
        del x_norm
        del x
        # top_dist includes the zero self-distance, hence k + 1 neighbours
        top_dist, ranks = torch.topk(cur_dist, k + 1, largest=False)
        ones = torch.ones(j - i, k + 1, device=device)
        degrees = torch.scatter_add(degrees, dim=0, index=ranks.view(-1),
                                    src=ones.view(-1))
        # mean_dist_a[i:j] = (top_dist/k).sum(-1)
        # mean_dist_b[i:j] = (cur_dist/(data_x_len-1)).sum(-1)
    distribution = torch.zeros(data_x_len // 3, device=device)
    ones = torch.ones(data_x_len, device=device)
    distribution = torch.scatter_add(distribution, dim=0,
                                     index=(degrees - 1).long(), src=ones)
    return distribution
def forward_step(self, src_context_output, src_word_output, src_word_mask,
                 KG_word_output, KG_word_seq, KG_word_mask, hidden, tgt_word,
                 combine_knowledge=False):
    """
    Args:
        src_context_output (FloatTensor) : (batch_size, context_size)
        src_word_output (FloatTensor) : (batch_size, src_word_len, word_size)
        src_word_mask (FloatTensor) : (batch_size, src_word_len)
        KG_word_output (FloatTensor) : (batch_size, KG_max_word_len, KG_word_size)
        KG_word_seq (LongTensor) : (batch_size, src_max_word_len)
        KG_word_mask (FloatTensor) : (batch_size, KG_word_len)
        last_hidden (FloatTensor) : (batch_size, rnn_size)
        tgt_word (LongTensor) : (batch_size)
    Returns:
        hidden (FloatTensor) : (num_layer, batch_size, rnn_size)
        logit_word (FloatTensor) : (batch_size, num_vocab)
    """
    batch_size = tgt_word.size(0)

    # last_hidden : [batch_size, rnn_size]
    last_hidden = hidden[-1]
    # attn_src_word : [batch_size, src_word_size]
    attn_src_word, _ = self.attention(last_hidden, src_word_output,
                                      src_word_output, src_word_mask)
    # attn_KG_word : [batch_size, src_KG_size]
    KG_attention_query = last_hidden + attn_src_word
    attn_KG_word, attn_KG_scores = self.attention(KG_attention_query,
                                                  KG_word_output,
                                                  KG_word_output, KG_word_mask)

    # tgt_word_emb : [batch_size, emb_size]
    tgt_word_emb = self.word_embedding(tgt_word)
    # rnn_inputs : [batch_size, emb_size + src_word_size]
    rnn_inputs = torch.cat([tgt_word_emb, attn_src_word, attn_KG_word], dim=1)
    # rnn_output : [batch_size, rnn_size] ; hidden : [num_layer, batch_size, rnn_size]
    rnn_output, hidden = self.rnn_cell(rnn_inputs, hidden)

    # prob_word : [batch_size, num_vocab]
    prob_word = self.output(rnn_output)

    if combine_knowledge == True:
        # copy_prob : [batch_size, num_vocab]
        copy_prob = torch.zeros(batch_size, self.num_vocab,
                                device=src_context_output.device)
        copy_prob = torch.scatter_add(input=copy_prob, dim=1,
                                      index=KG_word_seq,
                                      src=attn_KG_scores.permute(1, 0))
        # gen_dist_trans_input : [batch_size, emb_size + KG_rnn_size + rnn_size]
        gen_dist_trans_input = torch.cat([tgt_word_emb, attn_KG_word, rnn_output],
                                         dim=1)
        gen_dist = self.gen_dist_trans(gen_dist_trans_input)
        gen_dist = torch.sigmoid(gen_dist)
        # combined_prob_word : [batch_size, num_vocab]
        combined_prob_word = prob_word * gen_dist + (1 - gen_dist) * copy_prob
    else:
        combined_prob_word = prob_word

    # logit_word : [batch_size, num_vocab]
    logit_word = combined_prob_word.log()

    return hidden, logit_word
def sparseMatmul(A: torch.Tensor, x: torch.Tensor,
                 sizeA: Union[List[int], Tuple[int, int]],
                 indexA=None) -> torch.Tensor:
    """
    :param A: <Tensor t, 3> sparse triplet representation of an m*n matrix
    :param x: <Tensor n>
    :param sizeA: size of A, i.e. [m, n]
    :param indexA: optional, equals A[:, 0:2].to(torch.int64)
    :return: the matrix product A*x as <Tensor m>
    """
    indexA = indexA if indexA is not None else A[:, 0:2].to(torch.int64)
    item1 = A[:, 2] * x[indexA[:, 1]]
    return torch.scatter_add(
        torch.zeros(sizeA[0], dtype=x.dtype, device=x.device), 0,
        indexA[:, 0], item1)
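# Usage sketch for sparseMatmul with a hand-built (row, col, value) matrix;
# the dense comparison is only for verification.
import torch

A = torch.tensor([[0., 1., 2.],   # A[0, 1] = 2
                  [1., 0., 1.]])  # A[1, 0] = 1
x = torch.tensor([3., 4.])
dense = torch.zeros(2, 2)
dense[A[:, 0].long(), A[:, 1].long()] = A[:, 2]
assert torch.allclose(sparseMatmul(A, x, (2, 2)), dense @ x)  # both give [8., 3.]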
def forward(self, x):
    if self.metric_type == MetricType.minkowski:
        k = torch.cdist(x, self.train_data, p=self.p,
                        compute_mode="donot_use_mm_for_euclid_dist")
    elif self.metric_type == MetricType.wminkowski:
        k = torch.cdist(self.w * x, self.train_data, p=self.p,
                        compute_mode="donot_use_mm_for_euclid_dist")
    elif self.metric_type == MetricType.seuclidean:
        k = torch.cdist(self.V * x, self.train_data, p=2,
                        compute_mode="donot_use_mm_for_euclid_dist")
    elif self.metric_type == MetricType.mahalanobis:
        # We use the Cholesky decomposition to calculate the Mahalanobis distance
        # Mahalanobis distance d^2(x, x') = (x - x')T VI (x - x')
        # using Cholesky decomposition we have VI = LT L
        # then:
        #   d^2(x, x') = (x - x')T (LT L) (x - x')
        #              = (Lx - Lx')T (Lx - Lx')
        k = torch.cdist(torch.mm(x, self.L), self.train_data, p=2,
                        compute_mode="donot_use_mm_for_euclid_dist")

    d, k = torch.topk(k, self.n_neighbors, dim=1, largest=False)
    output = torch.index_select(self.train_labels, 0,
                                k.view(-1)).view(-1, self.n_neighbors)

    if self.weights == "distance":
        d = torch.pow(d, -1)
        inf_mask = torch.isinf(d)
        inf_row = torch.any(inf_mask, dim=1)
        d[inf_row] = inf_mask[inf_row].float()
    else:
        d = torch.ones_like(k, dtype=torch.float32)

    if self.classification:  # classification
        output = torch.scatter_add(self.proba_tensor, 1, output, d)
        proba_sum = output.sum(1, keepdim=True)
        proba_sum = torch.where(proba_sum == 0, self.one_tensor, proba_sum)
        output = torch.pow(proba_sum, -1) * output
        if self.perform_class_select:
            return torch.index_select(self.classes, 0,
                                      torch.argmax(output, dim=1)), output
        else:
            return torch.argmax(output, dim=1), output
    else:  # regression
        output = d * output
        if self.weights != "distance":
            output = output.sum(1) / self.n_neighbors
        else:
            denom = d.sum(1)
            output = output.sum(1) / denom
        return output
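# Sketch of the classification vote above under invented sizes: neighbour
# labels are scattered into a per-class score tensor, weighted by d.
import torch

labels = torch.tensor([[0, 2, 2], [1, 1, 0]])     # (batch, n_neighbors)
w = torch.ones_like(labels, dtype=torch.float32)  # uniform weights
proba = torch.scatter_add(torch.zeros(2, 3), 1, labels, w)
proba = proba / proba.sum(1, keepdim=True)        # rows: [1/3, 0, 2/3], [1/3, 2/3, 0]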
def scatter_add(device: Type[draw_devices],
                token_sizes: Type[draw_token_sizes],
                dim: Type[draw_embedding_dims], *,
                timer: TimerSuit):
    device = device()
    token_size, num = token_sizes(), token_sizes()
    if num > token_size:
        token_size, num = num, token_size
    in_dim = dim()

    inputs = torch.randn((token_size, in_dim), requires_grad=True, device=device)
    index1 = torch.randint(0, num, (token_size,), device=device)
    index2 = index1[:, None].expand_as(inputs)

    with timer.rua_forward:
        actual = rua.scatter_add(tensor=inputs, index=index1)

    with timer.naive_forward:
        excepted = torch.scatter_add(
            torch.zeros((num, in_dim), device=device),
            src=inputs,
            index=index2,
            dim=0,
        )

    with timer.rua_backward:
        torch.autograd.grad(
            actual, inputs, torch.ones_like(actual),
            create_graph=False,
        )

    with timer.naive_backward:
        torch.autograd.grad(
            excepted, inputs, torch.ones_like(excepted),
            create_graph=False,
        )
def forward(self, input_ids, input_context, context_mask, decode_input,
            decode_target, use_beam_search, beam_width):
    """
    :param input_ids: input id sequence used for copy-augmented decoding
    :param input_context: encoded context (bsz, enc_seq, dim)
    :param context_mask: the encoder's input_mask, used to ignore padded inputs
    :param decode_input: decoder input ==> training only
    :param decode_target: decoder target ==> training only, None at test time
    :param use_beam_search: whether to decode with beam search
    :param beam_width: beam width
    :return: the loss at training time, the decoded sequence at test time
    """
    bsz = input_context.size(0)
    net_state = self.init_hidden_unit.repeat(bsz, 1)
    if decode_target is not None:
        dec_list = []
        decode_emb = self.embedding_matrix(
            decode_input)  # part of the input (bsz, dec_seq, dim)
        for i in range(self.seq_len):
            # step 1: obtain the current context representation via attention
            attn_score = torch.einsum("bsd,bd->bs", input_context, net_state)
            attn_score.mul_(self.scale)
            attn_score += (1.0 - context_mask) * (-1e30)
            attn_prob = torch.softmax(attn_score, dim=-1)
            attn_vec = torch.einsum("bs,bsd->bd", attn_prob, input_context)
            # step 2: update the state
            x = torch.cat([attn_vec, decode_emb[:, i, :], net_state], dim=-1)
            reset_sig = self.reset_gate(x)
            update_sig = self.update_gate(x)
            update_value = self.update(
                torch.cat(
                    [attn_vec, decode_emb[:, i, :], reset_sig * net_state],
                    dim=-1))
            net_state = (1 - update_sig) * net_state + update_sig * update_value
            # step 3: compute the output distribution --> MoS
            vocab_prob_list = []
            # pi_k = self.pi_mos(net_state)
            # for k in range(args["mos"]):
            #     output = self.output[k](net_state)
            #     vocab_logits = torch.nn.functional.linear(input=output, weight=self.embedding_matrix.weight)
            #     vocab_prob_list.append(torch.softmax(vocab_logits, dim=-1)[..., None])
            # vocab_prob = torch.einsum("bk,bvk->bv", pi_k, torch.cat(vocab_prob_list, dim=-1))
            output = self.output(net_state)
            vocab_logits = torch.nn.functional.linear(
                input=output, weight=self.embedding_matrix.weight)
            vocab_prob = torch.softmax(vocab_logits, dim=-1)
            input_logits = torch.einsum("bd,bsd->bs",
                                        self.copy_output(net_state),
                                        input_context)
            input_logits += (1.0 - context_mask) * (-1e30)
            input_prob = torch.softmax(input_logits, dim=-1)  # (bsz, enc_seq)
            # step 4: mix the two distributions according to mode_sig
            mode_sig = self.mode_select(net_state)
            vocab_prob = vocab_prob * mode_sig
            vocab_prob = torch.scatter_add(vocab_prob, dim=1, index=input_ids,
                                           src=input_prob * (1 - mode_sig))
            dec_list.append(vocab_prob[:, None, :])
        # compute the loss
        predict = torch.cat(dec_list, dim=1)  # (bsz, dec_seq, vocab)
        predict = predict.view(size=(-1, predict.size(-1)))
        decode_target = decode_target.view(size=(-1, ))
        predict = torch.gather(predict, dim=1,
                               index=decode_target[:, None]).squeeze(dim=-1)
        init_loss = -torch.log(predict + self.epsilon)
        init_loss *= (decode_target != 0).float()
        loss = torch.sum(init_loss) / torch.nonzero(
            decode_target != 0, as_tuple=False).size(0)
        return loss[None].repeat(bsz)
    else:
        if use_beam_search:
            pass
        else:
            # greedy decoding
            dec_list = []
            for i in range(self.seq_len):
                # step 1: obtain the current context representation via attention
                attn_score = torch.einsum("bsd,bd->bs", input_context, net_state)
                attn_score.mul_(self.scale)
                attn_score += (1.0 - context_mask) * (-1e30)
                attn_prob = torch.softmax(attn_score, dim=-1)
                attn_vec = torch.einsum("bs,bsd->bd", attn_prob, input_context)
                # step 2: update the state
                if i == 0:
                    emb = self.embedding_matrix(
                        torch.full(size=(bsz, ),
                                   fill_value=args["start_token_id"],
                                   dtype=torch.int32).long().to(device))
                else:
                    emb = self.embedding_matrix(dec_list[i - 1].squeeze(dim=-1))
                x = torch.cat([attn_vec, emb, net_state], dim=-1)
                reset_sig = self.reset_gate(x)
                update_sig = self.update_gate(x)
                update_value = self.update(
                    torch.cat([attn_vec, emb, reset_sig * net_state], dim=-1))
                net_state = (1 - update_sig) * net_state + update_sig * update_value
                # step 3: compute distribution scores
                vocab_prob_list = []
                # pi_k = self.pi_mos(net_state)
                # for k in range(args["mos"]):
                #     output = self.output[k](net_state)
                #     vocab_logits = torch.nn.functional.linear(input=output, weight=self.embedding_matrix.weight)
                #     vocab_prob_list.append(torch.softmax(vocab_logits, dim=-1)[..., None])
                # vocab_prob = torch.einsum("bk,bvk->bv", pi_k, torch.cat(vocab_prob_list, dim=-1))
                output = self.output(net_state)
                vocab_logits = torch.nn.functional.linear(
                    input=output, weight=self.embedding_matrix.weight)
                vocab_prob = torch.softmax(vocab_logits, dim=-1)
                input_logits = torch.einsum("bd,bsd->bs",
                                            self.copy_output(net_state),
                                            input_context)
                input_logits += (1.0 - context_mask) * (-1e30)
                input_prob = torch.softmax(input_logits, dim=-1)  # (bsz, enc_seq)
                # step 4: mix the two distributions according to mode_sig
                mode_sig = self.mode_select(net_state)
                vocab_prob = vocab_prob * mode_sig
                vocab_prob = torch.scatter_add(vocab_prob, dim=1,
                                               index=input_ids,
                                               src=input_prob * (1 - mode_sig))
                dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None])
            return torch.cat(dec_list, dim=-1)
def _impl(x, dim, index, src):
    dim = dim.item()
    return (torch.scatter_add(x, dim, index, src), )
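# Minimal call-through for _impl (shapes assumed): dim arrives as a 0-dim
# tensor and is unwrapped with .item() before dispatching to scatter_add.
import torch

x = torch.zeros(3, 5)
index = torch.tensor([[0, 1, 2, 0, 0]])
src = torch.ones(1, 5)
(out,) = _impl(x, torch.tensor(0), index, src)  # out[index[0, j], j] += src[0, j]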
def forward(self, input_ids, encoder_rep, input_mask, decode_input,
            decode_target, use_beam_search, beam_width):
    bsz = input_ids.size(0)
    if decode_input is not None:  # training mode
        input_ids = input_ids[:, None, :].repeat(1, self.seq_len, 1)
        decode_embed = self.drop(self.word_emb(decode_input))
        all_ones = decode_embed.new_ones((self.seq_len, self.seq_len),
                                         dtype=torch.uint8)
        dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None]
        pos_seq = torch.arange(self.seq_len - 1, -1, -1.0, device=device,
                               dtype=decode_embed.dtype)
        pos_embed = self.drop(self.pos_emb(pos_seq))
        core_out = decode_embed.transpose(0, 1).contiguous()
        enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
        enc_mask_t = input_mask.transpose(0, 1).contiguous()
        for layer in self.layers:
            core_out = layer(dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t,
                             dec_mask=dec_attn_mask, enc_mask=enc_mask_t)
        core_out = self.drop(core_out.transpose(0, 1).contiguous())  # (bsz, dec_len, dim)
        output = self.output(core_out)
        vocab_logits = torch.nn.functional.linear(input=output,
                                                  weight=self.word_emb.weight)
        vocab_prob = torch.softmax(vocab_logits, dim=-1)
        input_logits = torch.einsum("bid,bjd->bij", self.copy_output(core_out),
                                    encoder_rep)  # (bsz, dec_len, enc_len)
        input_logits = input_logits + (1.0 - input_mask[:, None, :].repeat(
            1, self.seq_len, 1)) * (-1e30)
        input_prob = torch.softmax(input_logits, dim=-1)  # (bsz, dec_len, enc_len)
        mode_sig = self.mode_select(core_out)  # (bsz, dec_len, 1)
        vocab_prob = vocab_prob * mode_sig
        vocab_prob = torch.scatter_add(vocab_prob, dim=2, index=input_ids,
                                       src=(1 - mode_sig) * input_prob)
        vocab_prob = vocab_prob.view(-1, args["vocab_size"])
        decode_target = decode_target.view(-1)
        predict = torch.gather(vocab_prob, dim=1,
                               index=decode_target[:, None]).squeeze(dim=-1)
        init_loss = -torch.log(predict + self.epsilon)
        init_loss *= (decode_target != 0).float()
        loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0,
                                                    as_tuple=False).size(0)
        # for parallelism, expand the loss to shape (bsz,)
        return loss[None].repeat(bsz)
    else:  # validation/test decoding mode ==> relatively slow
        if not use_beam_search:  # greedy search ==> validation set
            dec_list = []
            decode_ids = torch.full(size=(bsz, 1),
                                    fill_value=args["start_token_id"],
                                    dtype=torch.int32).long().to(device)
            for i in range(1, self.seq_len + 1):
                if i > 1:
                    decode_ids = torch.cat([decode_ids, dec_list[i - 2]], dim=-1)
                decode_embed = self.word_emb(decode_ids)
                all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8)
                dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None]
                pos_seq = torch.arange(i - 1, -1, -1.0, device=device,
                                       dtype=decode_embed.dtype)
                pos_embed = self.pos_emb(pos_seq)
                core_out = decode_embed.transpose(0, 1).contiguous()
                enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
                enc_mask_t = input_mask.transpose(0, 1).contiguous()
                for layer in self.layers:
                    core_out = layer(dec_inp=core_out, r=pos_embed,
                                     enc_inp=enc_rep_t, dec_mask=dec_attn_mask,
                                     enc_mask=enc_mask_t)
                core_out = core_out.transpose(0, 1).contiguous()[:, -1, :]
                output = self.output(core_out)
                vocab_logits = torch.nn.functional.linear(
                    input=output, weight=self.word_emb.weight)
                vocab_prob = torch.softmax(vocab_logits, dim=-1)
                input_logits = torch.einsum("bd,bjd->bj",
                                            self.copy_output(core_out),
                                            encoder_rep)  # (bsz, enc_len)
                input_logits = input_logits + (1.0 - input_mask) * (-1e30)
                input_prob = torch.softmax(input_logits, dim=-1)  # (bsz, enc_len)
                mode_sig = self.mode_select(core_out)  # (bsz, 1)
                vocab_prob = vocab_prob * mode_sig
                vocab_prob = torch.scatter_add(vocab_prob, dim=1,
                                               index=input_ids,
                                               src=(1 - mode_sig) * input_prob)
                dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None])
            return torch.cat(dec_list, dim=-1)
        else:  # beam search
            # expand everything to beam_width * bsz
            """
            Notes:
            1. trigram blocking ==> on repetition, add -1e9 directly
               (respect the end_token boundary => only apply within it)
            2. length penalty, also respecting the end_token boundary
            """
            decode_ids = torch.full(size=(bsz * beam_width, 1),
                                    fill_value=args["start_token_id"],
                                    dtype=torch.int32).long().to(device)
            input_ids = input_ids.repeat((beam_width, 1))
            encoder_rep = encoder_rep.repeat((beam_width, 1, 1))
            input_mask = input_mask.repeat((beam_width, 1))
            dec_topK_log_probs = [0] * (beam_width * bsz)  # (bsz*beam) running log-prob sum per sequence
            dec_topK_sequences = [[] for _ in range(beam_width * bsz)]  # (bsz*beam, seq_len) decoded id sequences
            dec_topK_seq_lens = [1] * (beam_width * bsz)  # decoded lengths ==> biased by 1 to avoid div-by-0 in the length penalty
            for i in range(1, self.seq_len + 1):
                if i > 1:
                    input_decode_ids = torch.cat([
                        decode_ids,
                        torch.tensor(dec_topK_sequences).long().to(device)
                    ], dim=-1)
                else:
                    input_decode_ids = decode_ids
                decode_embed = self.word_emb(input_decode_ids)
                all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8)
                dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None]
                pos_seq = torch.arange(i - 1, -1, -1.0, device=device,
                                       dtype=decode_embed.dtype)
                pos_embed = self.pos_emb(pos_seq)
                core_out = decode_embed.transpose(0, 1).contiguous()
                enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
                enc_mask_t = input_mask.transpose(0, 1).contiguous()
                for layer in self.layers:
                    core_out = layer(dec_inp=core_out, r=pos_embed,
                                     enc_inp=enc_rep_t, dec_mask=dec_attn_mask,
                                     enc_mask=enc_mask_t)
                core_out = core_out.transpose(0, 1).contiguous()[:, -1, :]
                output = self.output(core_out)
                vocab_logits = torch.nn.functional.linear(
                    input=output, weight=self.word_emb.weight)
                vocab_prob = torch.softmax(vocab_logits, dim=-1)
                input_logits = torch.einsum("bd,bjd->bj",
                                            self.copy_output(core_out),
                                            encoder_rep)  # (bsz*beam, enc_len)
                input_logits = input_logits + (1.0 - input_mask) * (-1e30)
                input_prob = torch.softmax(input_logits, dim=-1)  # (bsz*beam, enc_len)
                mode_sig = self.mode_select(core_out)  # (bsz*beam, 1)
                vocab_prob = vocab_prob * mode_sig
                vocab_prob = torch.scatter_add(
                    vocab_prob, dim=1, index=input_ids,
                    src=(1 - mode_sig) * input_prob)  # (bsz*beam, vocab)
                vocab_logp = torch.log(vocab_prob + self.epsilon)  # take the log, adding eps
                # step 1: check for trigram-blocking overlap; only the last
                # trigram needs to be compared against the earlier ones
                if i > 4:  # only worth checking once the decoded length exceeds 4
                    for j in range(beam_width * bsz):
                        trigram_blocks = []
                        for k in range(3, i):
                            if dec_topK_sequences[j][k - 1] == args["end_token_id"]:
                                break
                            trigram_blocks.append(dec_topK_sequences[j][k - 3:k])
                        if len(trigram_blocks) > 1 and trigram_blocks[-1] in trigram_blocks[:-1]:
                            dec_topK_log_probs[j] += -1e9
                # step 2: for each sample, select the top-K sequences
                # ==> effectively rebuilding dec_topK_sequences
                for j in range(bsz):
                    topK_vocab_logp = vocab_logp[j::bsz]  # (k, vocab)
                    candidate_list = []
                    # easy to get wrong: at i == 1, do not generate K candidates
                    # per beam, otherwise beam search degenerates into greedy search
                    for k in range(beam_width):
                        ind = j + k * bsz
                        if args["end_token_id"] in dec_topK_sequences[ind]:
                            candidate_list.append({
                                "add_logit": 0,
                                "add_seq_len": 0,
                                "affiliate_k": k,
                                "add_token_id": args["end_token_id"],
                                "sort_logits": dec_topK_log_probs[ind] /
                                (dec_topK_seq_lens[ind]**args["beam_length_penalty"])
                            })
                        else:
                            k_logps, k_indices = topK_vocab_logp[k].topk(
                                dim=0, k=beam_width)
                            k_logps, k_indices = k_logps.cpu().numpy(), k_indices.cpu().numpy()
                            for l in range(beam_width):
                                aff = l if i == 1 else k
                                candidate_list.append({
                                    "add_logit": k_logps[l],
                                    "add_seq_len": 1,
                                    "affiliate_k": aff,
                                    "add_token_id": k_indices[l],
                                    "sort_logits": (dec_topK_log_probs[ind] + k_logps[l]) /
                                    ((dec_topK_seq_lens[ind] + 1)**args["beam_length_penalty"])
                                })
                        if i == 1:  # when decoding the first token, only one beam can be considered
                            break
                    candidate_list.sort(key=lambda x: x["sort_logits"], reverse=True)
                    candidate_list = candidate_list[:beam_width]
                    # fix up the sequences and update the top-K
                    c_dec_topK_sequences, c_dec_topK_log_probs, c_dec_topK_seq_lens = \
                        deepcopy(dec_topK_sequences), deepcopy(dec_topK_log_probs), deepcopy(dec_topK_seq_lens)
                    for k in range(beam_width):
                        ind = bsz * candidate_list[k]["affiliate_k"] + j
                        r_ind = bsz * k + j
                        father_seq, father_logits, father_len = \
                            c_dec_topK_sequences[ind], c_dec_topK_log_probs[ind], c_dec_topK_seq_lens[ind]
                        dec_topK_sequences[r_ind] = father_seq + [candidate_list[k]["add_token_id"]]
                        dec_topK_log_probs[r_ind] = father_logits + candidate_list[k]["add_logit"]
                        dec_topK_seq_lens[r_ind] = father_len + candidate_list[k]["add_seq_len"]
            return torch.tensor(dec_topK_sequences[:bsz]).long().to(device)
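# Isolated sketch of the copy mechanism used in the decoders above: attention
# mass over source positions is added onto the vocabulary distribution at the
# corresponding source token ids. All values are fabricated.
import torch

vocab_prob = torch.full((1, 6), 1.0 / 6)      # (bsz, vocab)
input_ids = torch.tensor([[2, 5, 2]])         # source token ids
input_prob = torch.tensor([[0.5, 0.3, 0.2]])  # attention over source positions
mode_sig = 0.7                                # generate-vs-copy gate
mixed = torch.scatter_add(vocab_prob * mode_sig, 1, input_ids,
                          (1 - mode_sig) * input_prob)
# mixed still sums to 1; ids 2 and 5 receive the copied mass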
def map_edge_lists(edge_lists: list, perform_unique=True, known_node_ids=None,
                   sequential_train_nodes=False, sequential_deg_nodes=0):
    print("Remapping Edges")

    defined_edges = []
    for edge_list in edge_lists:
        if edge_list is not None:
            defined_edges.append(edge_list)
    edge_lists = defined_edges

    if isinstance(edge_lists[0], pd.DataFrame):
        if isinstance(edge_lists[0].iloc[0][0], str):
            # need to take uniques using pandas for string datatypes,
            # since torch doesn't support strings
            return map_edge_list_dfs(edge_lists, known_node_ids,
                                     sequential_train_nodes,
                                     sequential_deg_nodes)

        new_edge_lists = []
        for edge_list in edge_lists:
            new_edge_lists.append(dataframe_to_tensor(edge_list))
        edge_lists = new_edge_lists

    all_edges = torch.cat(edge_lists)

    has_rels = False
    num_rels = 1
    unique_rels = torch.empty([0])
    mapped_rel_ids = torch.empty([0])
    if all_edges.size(1) == 3:
        has_rels = True

    output_dtype = torch.int32

    if perform_unique:
        unique_src = torch.unique(all_edges[:, 0])
        unique_dst = torch.unique(all_edges[:, -1])

        if known_node_ids is None:
            unique_nodes = torch.unique(torch.cat([unique_src, unique_dst]),
                                        sorted=True)
        else:
            unique_nodes = torch.unique(
                torch.cat([unique_src, unique_dst] + known_node_ids),
                sorted=True)

        num_nodes = unique_nodes.size(0)

        if has_rels:
            unique_rels = torch.unique(all_edges[:, 1], sorted=True)
            num_rels = unique_rels.size(0)
    else:
        # torch.max on a 1-D tensor returns a 0-dim tensor, not a (values, indices) pair
        num_nodes = torch.max(all_edges[:, 0])
        unique_nodes = torch.arange(num_nodes).to(output_dtype)

        if has_rels:
            num_rels = torch.max(all_edges[:, 1])
            unique_rels = torch.arange(num_rels).to(output_dtype)

    if sequential_train_nodes or sequential_deg_nodes > 0:
        seq_nodes = None
        if sequential_train_nodes and sequential_deg_nodes <= 0:
            print("Sequential Train Nodes")
            seq_nodes = known_node_ids[0]
        else:
            out_degrees = torch.zeros([num_nodes], dtype=torch.int32)
            out_degrees = torch.scatter_add(
                out_degrees, 0,
                torch.squeeze(edge_lists[0][:, 0]).to(torch.int64),
                torch.ones([edge_lists[0].shape[0]], dtype=torch.int32))

            in_degrees = torch.zeros([num_nodes], dtype=torch.int32)
            in_degrees = torch.scatter_add(
                in_degrees, 0,
                torch.squeeze(edge_lists[0][:, -1]).to(torch.int64),
                torch.ones([edge_lists[0].shape[0]], dtype=torch.int32))

            degrees = in_degrees + out_degrees
            deg_argsort = torch.argsort(degrees, dim=0, descending=True)
            high_degree_nodes = deg_argsort[:sequential_deg_nodes]

            print("High Deg Nodes Degree Sum:",
                  torch.sum(degrees[high_degree_nodes]).numpy())

            if sequential_train_nodes and sequential_deg_nodes > 0:
                print("Sequential Train and High Deg Nodes")
                seq_nodes = torch.unique(
                    torch.cat([high_degree_nodes, known_node_ids[0]]))
                seq_nodes = seq_nodes.index_select(
                    0, torch.randperm(seq_nodes.size(0), dtype=torch.int64))
                print("Total Seq Nodes: ", seq_nodes.shape[0])
            else:
                print("Sequential High Deg Nodes")
                seq_nodes = high_degree_nodes

        seq_mask = torch.zeros(num_nodes, dtype=torch.bool)
        seq_mask[seq_nodes.to(torch.int64)] = True
        all_other_nodes = torch.arange(num_nodes, dtype=seq_nodes.dtype)
        all_other_nodes = all_other_nodes[~seq_mask]

        mapped_node_ids = -1 * torch.ones(num_nodes, dtype=output_dtype)
        mapped_node_ids[seq_nodes.to(torch.int64)] = torch.arange(
            seq_nodes.shape[0], dtype=output_dtype)
        mapped_node_ids[all_other_nodes.to(torch.int64)] = \
            seq_nodes.shape[0] + torch.randperm(num_nodes - seq_nodes.shape[0],
                                                dtype=output_dtype)
    else:
        mapped_node_ids = torch.randperm(num_nodes, dtype=output_dtype)

    if has_rels:
        mapped_rel_ids = torch.randperm(num_rels, dtype=output_dtype)

    # TODO may use too much memory if the max id is very large
    # Needed to support indexing w/ the remap
    if torch.max(unique_nodes) + 1 > num_nodes:
        extended_map = torch.zeros(torch.max(unique_nodes) + 1,
                                   dtype=output_dtype)
        extended_map[unique_nodes] = mapped_node_ids
    else:
        extended_map = mapped_node_ids

    all_edges = None  # can safely free this tensor

    output_edge_lists = []
    for edge_list in edge_lists:
        new_src = extended_map[edge_list[:, 0].to(torch.int64)]
        new_dst = extended_map[edge_list[:, -1].to(torch.int64)]

        if has_rels:
            new_rel = mapped_rel_ids[edge_list[:, 1].to(torch.int64)]
            output_edge_lists.append(
                torch.stack([new_src, new_rel, new_dst], dim=1))
        else:
            output_edge_lists.append(torch.stack([new_src, new_dst], dim=1))

    node_mapping = np.stack([unique_nodes.numpy(), mapped_node_ids.numpy()],
                            axis=1)

    rel_mapping = None
    if has_rels:
        rel_mapping = np.stack([unique_rels.numpy(), mapped_rel_ids.numpy()],
                               axis=1)

    return output_edge_lists, node_mapping, rel_mapping
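# The in/out-degree counting used above, in isolation with a toy edge list:
import torch

edges = torch.tensor([[0, 1], [0, 2], [2, 1]])
num_nodes = 3
ones = torch.ones(edges.shape[0], dtype=torch.int32)
out_deg = torch.scatter_add(torch.zeros(num_nodes, dtype=torch.int32), 0,
                            edges[:, 0].to(torch.int64), ones)
in_deg = torch.scatter_add(torch.zeros(num_nodes, dtype=torch.int32), 0,
                           edges[:, -1].to(torch.int64), ones)
# out_deg == [2, 0, 1], in_deg == [0, 2, 1]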
def forward(self, input_ids, encoder_rep, input_mask, decode_input,
            decode_target, use_beam_search, beam_width):
    bsz = input_ids.size(0)
    if decode_input is not None:  # training mode
        input_ids = input_ids[:, None, :].repeat(1, self.seq_len, 1)
        decode_embed = self.drop(self.word_emb(decode_input))
        all_ones = decode_embed.new_ones((self.seq_len, self.seq_len),
                                         dtype=torch.uint8)
        dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None]
        pos_seq = torch.arange(self.seq_len - 1, -1, -1.0, device=device,
                               dtype=decode_embed.dtype)
        pos_embed = self.drop(self.pos_emb(pos_seq))
        core_out = decode_embed.transpose(0, 1).contiguous()
        enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
        enc_mask_t = input_mask.transpose(0, 1).contiguous()
        for layer in self.layers:
            core_out = layer(dec_inp=core_out, r=pos_embed, enc_inp=enc_rep_t,
                             dec_mask=dec_attn_mask, enc_mask=enc_mask_t)
        core_out = self.drop(core_out.transpose(0, 1).contiguous())  # (bsz, dec_len, dim)
        output = self.output(core_out)
        vocab_logits = torch.nn.functional.linear(input=output,
                                                  weight=self.word_emb.weight)
        vocab_prob = torch.softmax(vocab_logits, dim=-1)
        input_logits = torch.einsum("bid,bjd->bij", self.copy_output(core_out),
                                    encoder_rep)  # (bsz, dec_len, enc_len)
        input_logits = input_logits + (1.0 - input_mask[:, None, :].repeat(
            1, self.seq_len, 1)) * (-1e30)
        input_prob = torch.softmax(input_logits, dim=-1)  # (bsz, dec_len, enc_len)
        mode_sig = self.mode_select(core_out)  # (bsz, dec_len, 1)
        vocab_prob = vocab_prob * mode_sig
        vocab_prob = torch.scatter_add(vocab_prob, dim=2, index=input_ids,
                                       src=(1 - mode_sig) * input_prob)
        vocab_prob = vocab_prob.view(-1, args["vocab_size"])
        decode_target = decode_target.view(-1)
        predict = torch.gather(vocab_prob, dim=1,
                               index=decode_target[:, None]).squeeze(dim=-1)
        init_loss = -torch.log(predict + self.epsilon)
        init_loss *= (decode_target != 0).float()
        loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0,
                                                    as_tuple=False).size(0)
        # for parallelism, expand the loss to shape (bsz,)
        return loss[None].repeat(bsz)
    else:  # validation/test decoding mode ==> relatively slow
        dec_list = []
        decode_ids = torch.full(size=(bsz, 1),
                                fill_value=args["start_token_id"],
                                dtype=torch.int32).long().to(device)
        for i in range(1, self.seq_len + 1):
            if i > 1:
                decode_ids = torch.cat([decode_ids, dec_list[i - 2]], dim=-1)
            decode_embed = self.word_emb(decode_ids)
            all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8)
            dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None]
            pos_seq = torch.arange(i - 1, -1, -1.0, device=device,
                                   dtype=decode_embed.dtype)
            pos_embed = self.pos_emb(pos_seq)
            core_out = decode_embed.transpose(0, 1).contiguous()
            enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
            enc_mask_t = input_mask.transpose(0, 1).contiguous()
            for layer in self.layers:
                core_out = layer(dec_inp=core_out, r=pos_embed,
                                 enc_inp=enc_rep_t, dec_mask=dec_attn_mask,
                                 enc_mask=enc_mask_t)
            core_out = core_out.transpose(0, 1).contiguous()[:, -1, :]
            output = self.output(core_out)
            vocab_logits = torch.nn.functional.linear(
                input=output, weight=self.word_emb.weight)
            vocab_prob = torch.softmax(vocab_logits, dim=-1)
            input_logits = torch.einsum("bd,bjd->bj",
                                        self.copy_output(core_out),
                                        encoder_rep)  # (bsz, enc_len)
            input_logits = input_logits + (1.0 - input_mask) * (-1e30)
            input_prob = torch.softmax(input_logits, dim=-1)  # (bsz, enc_len)
            mode_sig = self.mode_select(core_out)  # (bsz, 1)
            vocab_prob = vocab_prob * mode_sig
            vocab_prob = torch.scatter_add(vocab_prob, dim=1, index=input_ids,
                                           src=(1 - mode_sig) * input_prob)
            dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None])
        return torch.cat(dec_list, dim=-1)
def model(x, y, alt_av, alt_ids_cuda):
    # global parameters in the model
    if diagonal_alpha:
        alpha_mu = pyro.sample(
            "alpha",
            dist.Normal(torch.zeros(len(non_mix_params), device=x.device),
                        1).to_event(1))
    else:
        alpha_mu = pyro.sample(
            "alpha",
            dist.MultivariateNormal(
                torch.zeros(len(non_mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(non_mix_params), device=x.device))))

    if diagonal_beta_mu:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.Normal(torch.zeros(len(mix_params), device=x.device),
                        1.).to_event(1))
    else:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.MultivariateNormal(
                torch.zeros(len(mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(mix_params), device=x.device))))

    # Vector of variances for each of the d variables
    theta = pyro.sample(
        "theta",
        dist.HalfCauchy(
            10. * torch.ones(len(mix_params), device=x.device)).to_event(1))
    # Lower cholesky factor of a correlation matrix
    eta = 1. * torch.ones(1, device=x.device)  # Implies a uniform distribution over correlation matrices
    L_omega = pyro.sample("L_omega", dist.LKJCorrCholesky(len(mix_params), eta))
    # Lower cholesky factor of the covariance matrix
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)

    # local parameters in the model
    random_params = pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_mu.repeat(num_resp, 1),
                                scale_tril=L_Omega).to_event(1))

    # vector of respondent parameters: global + local (respondent) params
    params_resp = torch.cat([alpha_mu.repeat(num_resp, 1), random_params],
                            dim=-1)

    # vector of betas of MXL (may repeat the same learnable parameter multiple
    # times; random + fixed effects)
    beta_resp = torch.cat([
        params_resp[:, beta_to_params_map[i]] for i in range(num_alternatives)
    ], dim=-1)

    with pyro.plate("locals", len(x), subsample_size=BATCH_SIZE) as ind:
        with pyro.plate("data_resp", T):
            # compute utilities for each alternative
            utilities = torch.scatter_add(
                zeros_vec[:, ind, :], 2,
                alt_ids_cuda[ind, :, :].transpose(0, 1),
                torch.mul(x[ind, :, :].transpose(0, 1), beta_resp[ind, :]))
            # adjust utility for unavailable alternatives
            utilities += alt_av[ind, :, :].transpose(0, 1)

            # likelihood
            pyro.sample("obs", dist.Categorical(logits=utilities),
                        obs=y[ind, :].transpose(0, 1))
def backward(ctx, grad_output):
    featureH, featureL, label, centers, margins, gamma = ctx.saved_tensors
    num_pair = featureH.shape[0]
    num_class = centers.shape[0]

    centers_normed = F.normalize(centers, dim=1)  # Cx1024
    featureH_normed = F.normalize(featureH, dim=1)  # N'x1024
    distmatH = torch.mm(featureH_normed, centers_normed.t())  # N'xC
    featureL_normed = F.normalize(featureL, dim=1)
    distmatL = torch.mm(featureL_normed, centers_normed.t())

    mask = distmatH.new_zeros(distmatH.size(), dtype=torch.long)
    mask.scatter_add_(
        1, label.long().unsqueeze(1),
        torch.ones(num_pair, 1, device=mask.device, dtype=torch.long))

    distHic = distmatH[mask == 1]
    distLic = distmatL[mask == 1]
    distHicL, hard_index_batch = torch.max(
        distmatH[mask == 0].view(num_pair, num_class - 1), dim=1)
    hard_index_batch[hard_index_batch >= label] += 1

    li1 = torch.acos(distHic) - torch.acos(distLic) + margins[0]
    li2 = torch.acos(distHic) - torch.acos(distHicL) + margins[1]

    centers_normed_batch = centers_normed.index_select(0, label.long())
    hard_normed_batch = centers_normed.index_select(0, hard_index_batch)

    d = -(1 - distHic.pow(2)).pow(-0.5)
    e = -(1 - distHicL.pow(2)).pow(-0.5)
    f = -(1 - distLic.pow(2)).pow(-0.5)

    I = torch.eye(featureH.shape[1], device=d.device)
    xcH = (I - torch.einsum('bi,bj->bij', (featureH_normed, featureH_normed))
           ) / featureH.norm(dim=1, keepdim=True).unsqueeze(-1)
    xcL = (I - torch.einsum('bi,bj->bij', (featureL_normed, featureL_normed))
           ) / featureL.norm(dim=1, keepdim=True).unsqueeze(-1)
    cc = (I - torch.einsum('bi,bj->bij', (centers_normed, centers_normed))
          ) / centers.norm(dim=1, keepdim=True).unsqueeze(-1)

    d = d.unsqueeze(1)
    e = e.unsqueeze(1)
    f = f.unsqueeze(1)

    counts_h = centers.new_ones(num_class)  # (C,)
    counts_hl = centers.new_ones(num_class)  # (C,)
    counts_c = centers.new_ones(num_class)  # (C,)
    ones_h = centers.new_ones(num_pair)  # (N',)
    ones_h[li2 <= 0] = 0
    ones_c = centers.new_ones(num_pair)  # (N',)
    ones_c[li1 <= 0] = 0
    grad_centers = centers.new_zeros(centers.size())  # Cx1024
    counts_h.scatter_add_(0, label.long(), ones_h)
    counts_hl.scatter_add_(0, hard_index_batch, ones_h)
    counts_c.scatter_add_(0, label.long(), ones_c)

    grad_centers_h = featureH_normed * d
    grad_centers_h[li2 <= 0] = 0
    grad_centers += torch.scatter_add(
        centers.new_zeros(centers.size()), 0,
        label.unsqueeze(1).expand(featureH_normed.size()).long(),
        grad_centers_h) / counts_h.unsqueeze(-1)

    grad_centers_hl = -featureH_normed * e
    grad_centers_hl[li2 <= 0] = 0
    grad_centers += torch.scatter_add(
        centers.new_zeros(centers.size()), 0,
        hard_index_batch.unsqueeze(1).expand(featureH_normed.size()),
        grad_centers_hl) / counts_hl.unsqueeze(-1)

    grad_centers_c = featureH_normed * d - featureL_normed * f
    grad_centers_c[li1 <= 0] = 0
    grad_centers += torch.scatter_add(
        centers.new_zeros(centers.size()), 0,
        label.unsqueeze(1).expand(featureH_normed.size()).long(),
        grad_centers_c * gamma) / counts_c.unsqueeze(-1)

    grad_centers /= num_pair

    grad = centers_normed_batch * d - hard_normed_batch * e
    grad[li2 <= 0] = 0
    grad_h = centers_normed_batch * d
    grad_h[li1 <= 0] = 0
    grad_h = grad_output * (grad + grad_h * gamma) / num_pair

    grad_l = -centers_normed_batch * f
    grad_l[li1 <= 0] = 0
    grad_l = grad_output * grad_l * gamma / num_pair

    return torch.bmm(xcH, grad_h.unsqueeze(-1)).squeeze(-1), \
        torch.bmm(xcL, grad_l.unsqueeze(-1)).squeeze(-1), None, \
        torch.bmm(cc, grad_centers.unsqueeze(-1)).squeeze(-1), None, None
def forward(self, src_context_output, src_word_output, src_word_len,
            KG_word_output, KG_word_len, KG_word_seq, tgt_word_input,
            combine_knowledge=False):
    """Compute decoder scores from context_output.

    Args:
        src_context_output (FloatTensor) : (batch_size, context_size)
        src_word_output (FloatTensor) : (batch_size, src_max_word_len, word_size)
        src_word_len (LongTensor) : (batch_size)
        KG_word_output (FloatTensor) : (batch_size, KG_max_word_len, KG_word_size)
        KG_word_len (LongTensor) : (batch_size)
        KG_word_seq (LongTensor) : (batch_size, src_max_word_len)
        tgt_word_input (LongTensor) : (batch_size, tgt_word_len)
    Returns:
        logit (FloatTensor) : (batch_size, tgt_word_len, num_vocab)
        coverage (FloatTensor) : (batch_size, tgt_word_len)
    """
    assert src_context_output.size(0) == src_word_output.size(0)
    assert src_context_output.size(0) == tgt_word_input.size(0)
    assert src_context_output.size(0) == KG_word_output.size(0)

    batch_size = src_context_output.size(0)
    max_src_len = src_word_output.size(1)
    max_KG_len = KG_word_output.size(1)
    max_tgt_len = tgt_word_input.size(1)

    # prepare the source word and KG word outputs
    # src_word_output : [src_word_len, batch_size, src_word_size]
    src_word_output = src_word_output.permute(1, 0, 2)
    # src_word_mask : [max_src_len, batch_size]
    src_word_mask = generate_mask_by_length(src_word_len, max_src_len)
    # KG_word_output : [KG_word_len, batch_size, KG_word_size]
    KG_word_output = KG_word_output.permute(1, 0, 2)
    # KG_word_mask : [max_KG_len, batch_size]
    KG_word_mask = generate_mask_by_length(KG_word_len, max_KG_len)

    # obtain word embedding and initial hidden states
    # tgt_word_emb : [batch_size, tgt_word_len, emb_size]
    tgt_word_emb = self.word_embedding(tgt_word_input)
    # hidden : [batch_size, num_layer, rnn_size]
    hidden = self.context2hidden(src_context_output).view(
        batch_size, self.num_layers, self.rnn_size)
    # hidden : [num_layer, batch_size, rnn_size]
    hidden = hidden.permute(1, 0, 2)

    logit_word_list = []
    coverage_list = []
    for word_index in range(max_tgt_len):
        # recurrence
        # last_hidden : [batch_size, rnn_size]
        last_hidden = hidden[-1]
        # attn_src_word : [batch_size, src_word_size]
        attn_src_word, _ = self.attention(last_hidden, src_word_output,
                                          src_word_output, src_word_mask)
        # attn_KG_word : [batch_size, src_KG_size]
        # attn_KG_scores : [max_KG_len, batch_size]
        KG_attention_query = last_hidden + attn_src_word
        attn_KG_word, attn_KG_scores = self.attention(
            KG_attention_query, KG_word_output, KG_word_output, KG_word_mask)

        # rnn_inputs : [batch_size, emb_size + src_word_size + KG_word_size]
        rnn_inputs = torch.cat(
            [tgt_word_emb[:, word_index], attn_src_word, attn_KG_word], dim=1)
        # rnn_output : [batch_size, rnn_size] ; hidden : [num_layer, batch_size, rnn_size]
        rnn_output, hidden = self.rnn_cell(rnn_inputs, hidden)

        # prob_word : [batch_size, num_vocab]
        prob_word = self.output(rnn_output)

        if combine_knowledge == True:
            # copy_prob : [batch_size, num_vocab]
            copy_prob = torch.zeros(batch_size, self.num_vocab,
                                    device=src_context_output.device)
            copy_prob = torch.scatter_add(input=copy_prob, dim=1,
                                          index=KG_word_seq,
                                          src=attn_KG_scores.permute(1, 0))
            # gen_dist_trans_input : [batch_size, emb_size + KG_rnn_size + rnn_size]
            gen_dist_trans_input = torch.cat(
                [tgt_word_emb[:, word_index], attn_KG_word, rnn_output], dim=1)
            gen_dist = self.gen_dist_trans(gen_dist_trans_input)
            gen_dist = torch.sigmoid(gen_dist)
            # combined_prob_word : [batch_size, num_vocab]
            combined_prob_word = prob_word * gen_dist + (1 - gen_dist) * copy_prob
        else:
            combined_prob_word = prob_word

        # logit_word : [batch_size, num_vocab]
        logit_word = combined_prob_word.log()

        # coverage_score : [batch_size]
        coverage_score = cos_sim(last_hidden, hidden[-1])
        coverage_score = F.relu(coverage_score)

        logit_word_list.append(logit_word)

    # logit : [batch_size, max_tgt_len, num_vocab]
    logit = torch.stack(logit_word_list, dim=1)
    return logit
def chamfer_loss(xyz1, xyz2, batch1=None, batch2=None, reduce='mean'):
    """
    Calculates the Chamfer distance between two batches of point clouds.

    :param xyz1: a point cloud of shape ``(b, n1, k)`` or ``(bn1, k)``.
    :param xyz2: a point cloud of shape ``(b, n2, k)`` or ``(bn2, k)``.
    :param batch1: batch indicator of shape ``(bn1,)`` for each point in `xyz1`.
        If ``None``, all points are assumed to be in the same point cloud.
        For batched (3D) point clouds, ``None`` must be passed.
    :param batch2: batch indicator of shape ``(bn2,)`` for each point in `xyz2`.
        If ``None``, all points are assumed to be in the same point cloud.
        For batched (3D) point clouds, ``None`` must be passed.
    :param reduce: ``'mean'`` or ``'sum'``. Default: ``'mean'``.
    :return: the Chamfer distance between the inputs.
    """
    assert xyz1.ndim == xyz2.ndim, 'two point clouds do not have the same number of dimensions'
    assert len(xyz1.shape) in (2, 3) and len(xyz2.shape) in (2, 3), 'unknown shape of tensors'
    assert xyz1.shape[-1] == xyz2.shape[-1], 'mismatched feature dimension'
    assert reduce in ('mean', 'sum'), 'Unknown reduce method'

    if xyz1.dim() == 3:
        assert batch1 is None and batch2 is None, 'batch indicators must be None when point clouds are 3D tensors'
        assert xyz1.shape[0] == xyz2.shape[0], 'mismatched batch dimension'
        batch_idx = T.arange(xyz1.shape[0], device=xyz1.device)
        batch_idx = batch_idx[:, None]
        batch1 = T.zeros(*xyz1.shape[:-1], device=xyz1.device, dtype=T.long)
        batch1 += batch_idx
        batch1 = batch1.flatten()
        batch2 = T.zeros(*xyz2.shape[:-1], device=xyz2.device, dtype=T.long)
        batch2 += batch_idx
        batch2 = batch2.flatten()
        # flatten to (b*n, k); the original hard-coded 3 here, which breaks for k != 3
        xyz1 = xyz1.reshape(-1, xyz1.shape[-1])
        xyz2 = xyz2.reshape(-1, xyz2.shape[-1])

    dist1, dist2, deg1, deg2 = stack_chamfer_distance(xyz1, xyz2, batch1, batch2)

    if batch1 is None:
        loss_1 = T.mean(dist1)
        loss_2 = T.mean(dist2)
    else:
        loss_1 = T.zeros_like(deg1).to(dist1.dtype)
        loss_1 = T.scatter_add(loss_1, 0, batch1, dist1)
        loss_1 = loss_1 / deg1
        loss_2 = T.zeros_like(deg2).to(dist2.dtype)
        loss_2 = T.scatter_add(loss_2, 0, batch2, dist2)
        loss_2 = loss_2 / deg2
        reduce = T.sum if reduce == 'sum' else T.mean
        loss_1 = reduce(loss_1)
        loss_2 = reduce(loss_2)

    return loss_1 + loss_2
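# The batched branch above is a scatter-mean: per-cloud distance sums divided
# by per-cloud point counts. A standalone sketch with fabricated distances:
import torch as T

dist = T.tensor([1., 2., 3., 4.])   # per-point nearest-neighbour distances
batch = T.tensor([0, 0, 1, 1])      # cloud id of each point
deg = T.tensor([2., 2.])            # points per cloud
summed = T.scatter_add(T.zeros(2), 0, batch, dist)
mean_per_cloud = summed / deg       # tensor([1.5, 3.5])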
def forward(self, src_graph_vecs, graph_batch, graph_tensors, init_atoms, orders):
    batch_size = len(orders)
    hgraph = HTuple(
        node=src_graph_vecs.new_zeros(graph_tensors[0].size(0), self.hidden_size),
        mess=self.rnn_cell.get_init_state(graph_tensors[1]),
        vmask=self.itensor.new_zeros(graph_tensors[0].size(0)),
        emask=self.itensor.new_zeros(graph_tensors[1].size(0)),
    )
    # We assume that there is no edge directly connecting two initial subgraphs
    all_topo_preds, all_atom_preds, all_bond_preds = [], [], []
    graph_tensors = self.mpn.embed_graph(graph_tensors) + (graph_tensors[-1],)  # preprocess graph tensors

    visited = set()
    new_atoms = [a for alist in init_atoms for a in alist]
    self.update_graph_mask(graph_batch, new_atoms, hgraph, visited)

    maxt = max([len(x) for x in orders])
    for t in range(maxt):
        batch_list = [i for i in range(batch_size) if t < len(orders[i])]
        assert hgraph.vmask[0].item() == 0 and hgraph.emask[0].item() == 0

        cur_graph_tensors = self.apply_graph_mask(graph_tensors, hgraph)
        vmask = hgraph.vmask.unsqueeze(-1).float()
        hgraph.node, hgraph.mess = self.mpn.encoder(*cur_graph_tensors[:-1], mask=vmask)

        new_atoms = []
        for i in batch_list:
            xid, yid, front, fbond = orders[i][t]
            st, le = graph_tensors[-1][i]
            stop = 1 if yid is None else 0

            gvec = hgraph.node[st:st + le].sum(dim=0)
            cxt_vec = torch.cat((gvec, hgraph.node[xid]), dim=-1)
            all_topo_preds.append((cxt_vec, i, stop))

            if stop == 0:
                new_atoms.append(yid)
                ylabel = graph_batch.nodes[yid]['label']
                atom_type = self.avocab[ylabel]
                all_atom_preds.append((cxt_vec, i, atom_type))

                hist = torch.zeros_like(hgraph.node[xid])  # avoid inplace operation
                atom_vec = self.E_a[atom_type]
                assert front[0] == xid
                for zid, bt in zip(front, fbond):
                    cur_hnode = torch.cat([hist, atom_vec], dim=-1)
                    cur_hnode = self.R_bond(cur_hnode)
                    pairs = torch.cat([gvec, cur_hnode, hgraph.node[zid]], dim=-1)
                    all_bond_preds.append((pairs, i, bt))
                    if bt > 0:
                        bond_vec = self.E_b[bt]
                        z_hnode = torch.cat([hgraph.node[zid], bond_vec], dim=-1)
                        hist += self.W_bond(z_hnode)

        self.update_graph_mask(graph_batch, new_atoms, hgraph, visited)

    topo_vecs, batch_idx, topo_labels = zip_tensors(all_topo_preds)
    topo_scores = self.get_topo_score(src_graph_vecs, batch_idx, topo_vecs)
    topo_loss = self.topo_loss(topo_scores, topo_labels.float())
    topo_acc = get_accuracy_bin(topo_scores, topo_labels)
    topo_loss_batch = torch.zeros(batch_size).cuda()
    # equivalent to: for idx, bid in enumerate(batch_idx): topo_loss_batch[bid] += topo_loss[idx]
    topo_loss_batch = torch.scatter_add(topo_loss_batch, dim=0, index=batch_idx, src=topo_loss)

    atom_vecs, batch_idx, atom_labels = zip_tensors(all_atom_preds)
    atom_scores = self.get_atom_score(src_graph_vecs, batch_idx, atom_vecs)
    atom_loss = self.atom_loss(atom_scores, atom_labels)
    atom_acc = get_accuracy(atom_scores, atom_labels)
    atom_loss_batch = torch.zeros(batch_size).cuda()
    atom_loss_batch = torch.scatter_add(atom_loss_batch, dim=0, index=batch_idx, src=atom_loss)

    bond_vecs, batch_idx, bond_labels = zip_tensors(all_bond_preds)
    bond_scores = self.get_bond_score(src_graph_vecs, batch_idx, bond_vecs)
    bond_loss = self.bond_loss(bond_scores, bond_labels)
    bond_acc = get_accuracy(bond_scores, bond_labels)
    bond_loss_batch = torch.zeros(batch_size).cuda()
    bond_loss_batch = torch.scatter_add(bond_loss_batch, dim=0, index=batch_idx, src=bond_loss)

    loss = topo_loss_batch + atom_loss_batch + bond_loss_batch
    return loss, atom_acc, topo_acc, bond_acc