Example #1
 def forward(self, feats, edges=None):
     # allocate memory
     dtype, device = feats.dtype, feats.device
     edges = edges.view(-1, 3)
     V, E = feats.size(0), edges.size(0)
     pooled_v_pos = torch.zeros(V, feats.shape[-3], feats.shape[-1], feats.shape[-1], dtype=dtype, device=device)
     pooled_v_neg = torch.zeros(V, feats.shape[-3], feats.shape[-1], feats.shape[-1], dtype=dtype, device=device)
     # pool positive edges
     pos_inds = torch.where(edges[:, 1] > 0)
     pos_v_src = torch.cat([edges[pos_inds[0], 0], edges[pos_inds[0], 2]]).long()
     pos_v_dst = torch.cat([edges[pos_inds[0], 2], edges[pos_inds[0], 0]]).long()
     pos_vecs_src = feats[pos_v_src.contiguous()]
     pos_v_dst = pos_v_dst.view(-1, 1, 1, 1).expand_as(pos_vecs_src).to(device)
     pooled_v_pos = torch.scatter_add(pooled_v_pos, 0, pos_v_dst, pos_vecs_src)
     # pool negative edges
     neg_inds = torch.where(edges[:, 1] < 0)
     neg_v_src = torch.cat([edges[neg_inds[0], 0], edges[neg_inds[0], 2]]).long()
     neg_v_dst = torch.cat([edges[neg_inds[0], 2], edges[neg_inds[0], 0]]).long()
     neg_vecs_src = feats[neg_v_src.contiguous()]
     neg_v_dst = neg_v_dst.view(-1, 1, 1, 1).expand_as(neg_vecs_src).to(device)
     pooled_v_neg = torch.scatter_add(pooled_v_neg, 0, neg_v_dst, neg_vecs_src)
     # update node features
     enc_in = torch.cat([feats, pooled_v_pos, pooled_v_neg], 1)
     out = self.encoder(enc_in)
     return out
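The pooling above reduces to a small pattern worth seeing on its own. A minimal sketch with made-up shapes: each source node's feature block is accumulated into its destination's row, so the index handed to torch.scatter_add must first be expanded to the shape of the gathered source features.

import torch

feats = torch.randn(4, 8)                       # 4 nodes, 8-dim features
src = torch.tensor([0, 1, 2])                   # edge source nodes
dst = torch.tensor([1, 2, 3])                   # edge destination nodes
pooled = torch.scatter_add(
    torch.zeros_like(feats), 0,
    dst.view(-1, 1).expand(-1, feats.size(1)),  # index expanded to match src
    feats[src])                                 # pooled[d] += feats[s] per edge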
Example #2
 def forward(  # pylint: disable=missing-function-docstring
     self,
     segment_sizes: "torch.Tensor",
     features: "torch.Tensor",
 ) -> "torch.Tensor":
     n_seg = len(segment_sizes)
     encoded_features = self.encoder(features)
     segment_indices = torch.repeat_interleave(
         torch.arange(n_seg, device=features.device),
         segment_sizes.long(),
     )
     n_dim = encoded_features.shape[1]
     segment_sum = torch.scatter_add(
         input=torch.zeros((n_seg, n_dim),
                           dtype=encoded_features.dtype,
                           device=features.device),
         dim=0,
         index=segment_indices.view(-1, 1).expand(-1, n_dim),
         src=encoded_features,
     )
     out = self.norm(segment_sum)
     out = self.layer0(out) + out
     out = self.layer1(out) + out
     out = self.decoder(out).squeeze()
     out = self.sigmoid(out)
     return out
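The segment-sum step can be checked in isolation. A toy run with invented sizes:

import torch

segment_sizes = torch.tensor([2, 3])
features = torch.arange(5, dtype=torch.float32).view(5, 1)
index = torch.repeat_interleave(torch.arange(2), segment_sizes)
segment_sum = torch.scatter_add(
    torch.zeros(2, 1), 0, index.view(-1, 1), features)
# segment_sum == tensor([[1.], [9.]])  (0+1 and 2+3+4)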
Example #3
 def label_weight(this, shape, label):
     weight = torch.zeros(shape).to(label.device) + 0.1
     weight = torch.scatter_add(weight, 0, label,
                                torch.ones_like(label).float())
     weight = 1. / weight
     weight[-1] /= 200
     return weight
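A hypothetical call showing what the function computes: scatter_add counts label occurrences on top of the 0.1 floor, and the reciprocal turns counts into inverse-frequency weights.

import torch

label = torch.tensor([0, 0, 2, 3, 3, 3])
weight = torch.zeros(4) + 0.1
weight = torch.scatter_add(weight, 0, label, torch.ones_like(label).float())
# weight == tensor([2.1, 0.1, 1.1, 3.1]); 1. / weight favors rare labels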
Example #4
def update_balance_profile(
    balance_dict,
    gate_top_k_idx,
    _gate_score_top_k,
    gate_context,
    layer_idx,
    num_expert,
    balance_strategy,
):
    c_e = torch.scatter_add(
        torch.zeros(num_expert, device=gate_top_k_idx.device),
        0,
        gate_top_k_idx,
        torch.ones_like(gate_top_k_idx, dtype=torch.float),
    )
    for key in metrics:
        balance_dict[key][layer_idx] = metrics[key](c_e)
    S = gate_top_k_idx.shape[0]
    if balance_strategy == "gshard":
        gate_score_all = gate_context
        m_e = torch.sum(F.softmax(gate_score_all, dim=1), dim=0) / S
        balance_dict["gshard_loss"][layer_idx] = torch.sum(
            c_e * m_e) / num_expert / S
    elif balance_strategy == "noisy":
        balance_dict["noisy_loss"][layer_idx] = gate_context
Example #5
 def tensor_indexing_ops(self):
     x = torch.randn(2, 4)
     y = torch.randn(4, 4)
     t = torch.tensor([[0, 0], [1, 0]])
     mask = x.ge(0.5)
     i = [0, 1]
     return len((  # wrap in a tuple: len() takes a single argument
         torch.cat((x, x, x), 0),
         torch.concat((x, x, x), 0),
         torch.conj(x),
         torch.chunk(x, 2),
         torch.dsplit(torch.randn(2, 2, 4), i),
         torch.column_stack((x, x)),
         torch.dstack((x, x)),
         torch.gather(x, 0, t),
         torch.hsplit(x, i),
         torch.hstack((x, x)),
         torch.index_select(x, 0, torch.tensor([0, 1])),
         x.index(t),
         torch.masked_select(x, mask),
         torch.movedim(x, 1, 0),
         torch.moveaxis(x, 1, 0),
         torch.narrow(x, 0, 0, 2),
         torch.nonzero(x),
         torch.permute(x, (0, 1)),
         torch.reshape(x, (-1, )),
         torch.row_stack((x, x)),
         torch.select(x, 0, 0),
         torch.scatter(x, 0, t, x),
         x.scatter(0, t, x.clone()),
         torch.diagonal_scatter(y, torch.ones(4)),
         torch.select_scatter(y, torch.ones(4), 0, 0),
         torch.slice_scatter(x, x),
         torch.scatter_add(x, 0, t, x),
         x.scatter_(0, t, y),
         x.scatter_add_(0, t, y),
         # torch.scatter_reduce(x, 0, t, reduce="sum"),
         torch.split(x, 1),
         torch.squeeze(x, 0),
         torch.stack([x, x]),
         torch.swapaxes(x, 0, 1),
         torch.swapdims(x, 0, 1),
         torch.t(x),
         torch.take(x, t),
         torch.take_along_dim(x, torch.argmax(x)),
         torch.tensor_split(x, 1),
         torch.tensor_split(x, [0, 1]),
         torch.tile(x, (2, 2)),
         torch.transpose(x, 0, 1),
         torch.unbind(x),
         torch.unsqueeze(x, -1),
         torch.vsplit(x, i),
         torch.vstack((x, x)),
         torch.where(x),
         torch.where(t > 0, t, 0),
         torch.where(t > 0, t, t),
     ))
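Worth noting when reading the list above: torch.scatter overwrites at duplicate indices (which write wins is unspecified), while torch.scatter_add accumulates. A two-line contrast:

import torch

x = torch.zeros(3)
idx = torch.tensor([1, 1])
src = torch.tensor([1., 2.])
torch.scatter(x, 0, idx, src)      # position 1 holds 1. or 2. (nondeterministic)
torch.scatter_add(x, 0, idx, src)  # position 1 holds 3.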
Example #6
def calc_b_new(layer, b_tilde):
    with T.no_grad():
        act_shift = T.max(T.abs(b_tilde))

        # For b_tilde < 0 replace biases with b_tilde
        b_new = T.scatter(layer.bias, 0,
                          T.where(b_tilde < 0)[0], b_tilde[b_tilde < 0])

        # For b_tilde < 0 update biases
        b_new = T.scatter_add(-T.abs(b_new), 0,
                              T.where(b_tilde < 0)[0],
                              act_shift.expand_as(T.where(b_tilde < 0)[0]))
        # For b_tilde > 0 update biases
        b_new = T.scatter_add(b_new, 0,
                              T.where(b_tilde > 0)[0],
                              act_shift.expand_as(T.where(b_tilde > 0)[0]))

    return b_new, act_shift
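A toy run of the replace-then-shift pattern, with made-up bias values; T is assumed to alias torch as in the function above.

import torch as T

bias = T.tensor([0.5, -0.2, 0.7])
b_tilde = T.tensor([-1.0, 2.0, -3.0])
neg = T.where(b_tilde < 0)[0]                          # indices 0 and 2
b_new = T.scatter(bias, 0, neg, b_tilde[b_tilde < 0])  # tensor([-1.0, -0.2, -3.0])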
Example #7
    def compute_context(self, weights, relpos):
        """
        :param weights: (batsize, numheads, qlen, klen)
        :param relpos:  (batsize, qlen, klen, 1)
        :return:    # weighted sum over klen (batsize, numheads, qlen, dimperhead)
        """
        ret = None
        batsize = weights.size(0)
        numheads = weights.size(1)
        qlen = weights.size(2)
        device = weights.device

        # Naive implementation builds matrices of (batsize, numheads, qlen, klen, dimperhead)
        # whereas normal transformer only (batsize, numheads, qlen, klen) and (batsize, numheads, klen, dimperhead)
        for n in range(relpos.size(-1)):
            relpos_ = relpos[:, :, :, n]

            # map relpos_ to compact integer space of unique relpos_ entries
            relpos_unique = relpos_.unique()
            mapper = torch.zeros(
                relpos_unique.max() + 1, device=device, dtype=torch.long
            )  # mapper is relpos_unique but the other way around
            mapper[relpos_unique] = torch.arange(0,
                                                 relpos_unique.size(0),
                                                 device=device).long()
            relpos_mapped = mapper[
                relpos_]  # (batsize, qlen, klen) but ids are from 0 to number of unique relposes

            # sum up the attention weights which refer to the same relpos id
            # scatter: src is weights, index is relpos_mapped[:, None, :, :]
            # scatter: gathered[batch, head, qpos, relpos_mapped[batch, head, qpos, kpos]]
            #               += weights[batch, head, qpos, kpos]
            gathered = torch.zeros(batsize,
                                   numheads,
                                   qlen,
                                   relpos_unique.size(0),
                                   device=device)
            gathered = torch.scatter_add(
                gathered, -1,
                relpos_mapped[:, None, :, :].repeat(1, numheads, 1, 1),
                weights)  # repeat over heads only; the batch dim is already in place
            # --> (batsize, numheads, qlen, numunique): summed attention weights

            # get embeddings and update ret
            embs = self.embv(relpos_unique).view(
                relpos_unique.size(0), numheads,
                -1)  # (numunique, numheads, dimperhead)
            relposemb = torch.einsum("bhqn,nhd->bhqd", gathered, embs)
            if ret is None:
                ret = torch.zeros_like(relposemb)
            ret = ret + relposemb
        return ret
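The per-bucket summation at the heart of the loop, reduced to 2-D with ids that are already compact:

import torch

weights = torch.tensor([[0.2, 0.3, 0.5]])  # (qlen=1, klen=3)
relpos = torch.tensor([[0, 1, 0]])         # bucket id per key position
gathered = torch.scatter_add(torch.zeros(1, 2), -1, relpos, weights)
# gathered == tensor([[0.7, 0.3]])         # weight mass per unique relpos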
Example #8
def compute_degree_distr(data_x, k):

    data_y = data_x
    data_x_len = len(data_x)
    mean_dist_a = torch.zeros(len(data_x), device=device)
    mean_dist_b = torch.zeros(len(data_x), device=device)
    batch_sz = 700
    y_norm = (data_y**2).sum(-1).unsqueeze(0)
    data_y = data_y.t()
    degrees = torch.zeros(data_x_len, device=device)

    for i in range(0, data_x_len, batch_sz):
        j = min(data_x_len, i + batch_sz)
        x = data_x[i:j]
        x_norm = (x**2).sum(-1).unsqueeze(-1)
        cur_dist = x_norm + y_norm - 2 * torch.mm(x, data_y)
        del x_norm
        del x
        # top_dist includes the zero self-distance
        top_dist, ranks = torch.topk(cur_dist, k + 1, largest=False)
        ones = torch.ones(j - i, k + 1, device=device)

        degrees = torch.scatter_add(degrees,
                                    dim=0,
                                    index=ranks.view(-1),
                                    src=ones.view(-1))

        #mean_dist_a[i:j] = (top_dist/k).sum(-1)
        #mean_dist_b[i:j] = (cur_dist/(data_x_len-1)).sum(-1)

    distribution = torch.zeros(data_x_len // 3, device=device)
    ones = torch.ones(data_x_len, device=device)
    distribution = torch.scatter_add(distribution,
                                     dim=0,
                                     index=(degrees - 1).long(),
                                     src=ones)
    return distribution
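The final histogram step on its own, with hypothetical degrees: bin d-1 counts the nodes of degree d.

import torch

degrees = torch.tensor([1., 2., 2., 4.])
distribution = torch.scatter_add(
    torch.zeros(4), 0, (degrees - 1).long(), torch.ones_like(degrees))
# distribution == tensor([1., 2., 0., 1.])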
Example #9
    def forward_step(self, src_context_output, src_word_output, src_word_mask, 
                     KG_word_output, KG_word_seq, KG_word_mask, hidden, tgt_word, combine_knowledge = False):
        """
            Args:
                src_context_output (FloatTensor) : (batch_size, context_size)
                src_word_output (FloatTensor)    : (batch_size, src_word_len, word_size)
                src_word_mask (FloatTensor)      : (batch_size, src_word_len)
                KG_word_output (FloatTensor)     : (batch_size, KG_max_word_len, KG_word_size)
                KG_word_seq (LongTensor)         : (batch_size, src_max_word_len)
                KG_word_mask (FloatTensor)       : (batch_size, KG_word_len)
                last_hidden (FloatTensor)        : (batch_size, rnn_size)
                tgt_word (LongTensor)            : (batch_size)
            Returns:
                hidden (FloatTensor)             : (num_layer, batch_size, rnn_size)
                logit_word (FloatTensor)         : (batch_size, num_vocab)
        """
        batch_size = tgt_word.size(0)
        # last_hidden : [batch_size, rnn_size]
        last_hidden = hidden[-1]
        # attn_src_word : [batch_size, src_word_size] 
        attn_src_word,_ = self.attention(last_hidden, src_word_output, src_word_output, src_word_mask)
        # attn_KG_word : [batch_size, src_KG_size] 
        KG_attention_query = last_hidden + attn_src_word
        attn_KG_word, attn_KG_scores = self.attention(KG_attention_query, KG_word_output, KG_word_output, KG_word_mask)
        # tgt_word_emb = [batch_size, emb_size]
        tgt_word_emb = self.word_embedding(tgt_word) 
        # rnn_inputs : [batch_size, emb_size + src_word_size]
        rnn_inputs = torch.cat([tgt_word_emb, attn_src_word, attn_KG_word], dim = 1)
        # rnn_output : [batch_size, rnn_size] ; hidden : [num_layer, batch_size, rnn_size]
        rnn_output, hidden = self.rnn_cell(rnn_inputs, hidden)
        # prob_word : [batch_size, num_vocab]
        prob_word = self.output(rnn_output)
        if combine_knowledge:
            # copy_dist : [batch_size, num_vocab]
            copy_prob = torch.zeros(batch_size, self.num_vocab, device = src_context_output.device)
            copy_prob = torch.scatter_add(input = copy_prob, 
                                          dim = 1, 
                                          index = KG_word_seq, src = attn_KG_scores.permute(1, 0))
            # gen_dist_trans_input : [batch_size, emb_size + KG_rnn_size + rnn_size]
            gen_dist_trans_input = torch.cat([tgt_word_emb, attn_KG_word, rnn_output], dim = 1)
            gen_dist = self.gen_dist_trans(gen_dist_trans_input)
            gen_dist = torch.sigmoid(gen_dist)
            # combined_prob_word: [batch_size, num_vocab]
            combined_prob_word = prob_word * gen_dist + (1 - gen_dist) * copy_prob
        else:
            combined_prob_word = prob_word
        # logit_word : [batch_size, num_vocab]
        logit_word = combined_prob_word.log()

        return hidden, logit_word
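The copy branch in miniature: attention mass over knowledge words is scattered onto their vocabulary ids, giving a distribution that can be mixed with the generator's. Values are invented, and the scores here are already batch-first.

import torch

num_vocab = 6
KG_word_seq = torch.tensor([[2, 4, 2]])           # (batch, kg_len) vocab ids
attn_KG_scores = torch.tensor([[0.5, 0.3, 0.2]])  # (batch, kg_len)
copy_prob = torch.scatter_add(
    torch.zeros(1, num_vocab), 1, KG_word_seq, attn_KG_scores)
# copy_prob == tensor([[0., 0., 0.7, 0., 0.3, 0.]])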
Example #10
def sparseMatmul(A: torch.Tensor,
                 x: torch.Tensor,
                 sizeA: Union[List[int], Tuple[int, int]],
                 indexA=None) -> torch.Tensor:
    """
    :param A: <Tensor t, 3> m*n的矩阵的稀疏表示
    :param x: <Tensor n>
    :param sizeA: A的大小,即[m,n]
    :param indexA: 可选,是A[:, 0:2].to(torch.int64)
    :return: A*x(矩阵乘法)的结果 <Tensor m>
    """
    indexA = indexA if indexA is not None else A[:, 0:2].to(torch.int64)
    item1 = A[:, 2] * x[indexA[:, 1]]
    return torch.scatter_add(
        torch.zeros(sizeA[0], dtype=x.dtype, device=x.device), 0, indexA[:, 0],
        item1)
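A quick sanity check of the helper for A = [[1, 0], [0, 2]] and x = [3, 4]:

import torch

A = torch.tensor([[0., 0., 1.],  # (row, col, value) triplets
                  [1., 1., 2.]])
x = torch.tensor([3., 4.])
sparseMatmul(A, x, (2, 2))       # tensor([3., 8.])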
Example #11
    def forward(self, x):
        if self.metric_type == MetricType.minkowski:
            k = torch.cdist(x, self.train_data, p=self.p, compute_mode="donot_use_mm_for_euclid_dist")
        elif self.metric_type == MetricType.wminkowski:
            k = torch.cdist(self.w * x, self.train_data, p=self.p, compute_mode="donot_use_mm_for_euclid_dist")
        elif self.metric_type == MetricType.seuclidean:
            k = torch.cdist(self.V * x, self.train_data, p=2, compute_mode="donot_use_mm_for_euclid_dist")
        elif self.metric_type == MetricType.mahalanobis:
            # We use the Cholesky decomposition to calculate the Mahalanobis distance
            # Mahalanobis distance d^2(x, x') = (x - x')T VI (x - x')
            # using Cholesky decomposition we have VI = LT L
            # then:
            #                      d^2(x, x') = (x - x')T (LT L) (x - x')
            #                                 = (Lx - Lx')T (Lx - Lx')
            k = torch.cdist(torch.mm(x, self.L), self.train_data, p=2, compute_mode="donot_use_mm_for_euclid_dist")

        d, k = torch.topk(k, self.n_neighbors, dim=1, largest=False)
        output = torch.index_select(self.train_labels, 0, k.view(-1)).view(-1, self.n_neighbors)

        if self.weights == "distance":
            d = torch.pow(d, -1)
            inf_mask = torch.isinf(d)
            inf_row = torch.any(inf_mask, axis=1)
            d[inf_row] = inf_mask[inf_row].float()
        else:
            d = torch.ones_like(k, dtype=torch.float32)

        if self.classification:
            # classification
            output = torch.scatter_add(self.proba_tensor, 1, output, d)
            proba_sum = output.sum(1, keepdim=True)
            proba_sum = torch.where(proba_sum == 0, self.one_tensor, proba_sum)
            output = torch.pow(proba_sum, -1) * output

            if self.perform_class_select:
                return torch.index_select(self.classes, 0, torch.argmax(output, dim=1)), output
            else:
                return torch.argmax(output, dim=1), output
        else:
            # regression
            output = d * output
            if self.weights != "distance":
                output = output.sum(1) / self.n_neighbors
            else:
                denom = d.sum(1)
                output = output.sum(1) / denom
            return output
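The voting step isolated, assuming self.proba_tensor is a zeroed (batch, num_classes) buffer as the attribute name suggests: each neighbor's weight lands in the column of its class.

import torch

proba_tensor = torch.zeros(1, 3)              # (batch, num_classes)
neighbor_classes = torch.tensor([[0, 2, 2]])  # (batch, n_neighbors)
d = torch.ones(1, 3)                          # uniform vote weights
votes = torch.scatter_add(proba_tensor, 1, neighbor_classes, d)
# votes == tensor([[1., 0., 2.]]) -> argmax picks class 2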
Example #12
def scatter_add(device: Type[draw_devices],
                token_sizes: Type[draw_token_sizes],
                dim: Type[draw_embedding_dims], *, timer: TimerSuit):
    device = device()
    token_size, num = token_sizes(), token_sizes()
    if num > token_size:
        token_size, num = num, token_size
    in_dim = dim()

    inputs = torch.randn((token_size, in_dim),
                         requires_grad=True,
                         device=device)
    index1 = torch.randint(0, num, (token_size, ), device=device)
    index2 = index1[:, None].expand_as(inputs)

    with timer.rua_forward:
        actual = rua.scatter_add(tensor=inputs, index=index1)

    with timer.naive_forward:
        expected = torch.scatter_add(
            torch.zeros((num, in_dim), device=device),
            src=inputs,
            index=index2,
            dim=0,
        )

    with timer.rua_backward:
        torch.autograd.grad(
            actual,
            inputs,
            torch.ones_like(actual),
            create_graph=False,
        )

    with timer.naive_backward:
        torch.autograd.grad(
            expected,
            inputs,
            torch.ones_like(expected),
            create_graph=False,
        )
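For the naive branch, the expanded-index scatter_add is equivalent to index_add_ with the raw 1-D index, which is often the simpler spelling:

import torch

inputs = torch.randn(5, 3)
index1 = torch.tensor([0, 1, 0, 2, 1])
out = torch.zeros(3, 3).index_add_(0, index1, inputs)
# equals torch.scatter_add(torch.zeros(3, 3), 0,
#                          index1[:, None].expand_as(inputs), inputs)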
Example #13
 def forward(self, input_ids, input_context, context_mask, decode_input,
             decode_target, use_beam_search, beam_width):
     """
     :param input_ids: 用于解码增强的输入ids序列
     :param input_context: 编码的context (bsz, enc_seq, dim)
     :param context_mask: 沿用encoder部分的input_mask, 将pad的输入忽略
     :param decode_input: 解码输入 ==> 训练时才有
     :param decode_target: 解码目标 ==> 训练时才有, 测试时为空
     :param use_beam_search: 是否启动beam search解码
     :param beam_width: beam宽度
     :return: 训练时返回损失, 测试时返回解码序列
     """
     bsz = input_context.size(0)
     net_state = self.init_hidden_unit.repeat(bsz, 1)
     if decode_target is not None:
         dec_list = []
         decode_emb = self.embedding_matrix(
             decode_input)  # used as part of the step input (bsz, dec_seq, dim)
         for i in range(self.seq_len):
             # step 1: obtain the current context_rep via attention
             attn_score = torch.einsum("bsd,bd->bs", input_context,
                                       net_state)
             attn_score.mul_(self.scale)
             attn_score += (1.0 - context_mask) * (-1e30)
             attn_prob = torch.softmax(attn_score, dim=-1)
             attn_vec = torch.einsum("bs,bsd->bd", attn_prob, input_context)
             # step 2: update the state
             x = torch.cat([attn_vec, decode_emb[:, i, :], net_state],
                           dim=-1)
             reset_sig = self.reset_gate(x)
             update_sig = self.update_gate(x)
             update_value = self.update(
                 torch.cat(
                     [attn_vec, decode_emb[:, i, :], reset_sig * net_state],
                     dim=-1))
             net_state = (
                 1 - update_sig) * net_state + update_sig * update_value
             # step 3: compute the output distribution --> mos
             vocab_prob_list = []
             # pi_k = self.pi_mos(net_state)
             # for k in range(args["mos"]):
             #     output = self.output[k](net_state)
             #     vocab_logits = torch.nn.functional.linear(input=output, weight=self.embedding_matrix.weight)
             #     vocab_prob_list.append(torch.softmax(vocab_logits, dim=-1)[..., None])
             # vocab_prob = torch.einsum("bk,bvk->bv", pi_k, torch.cat(vocab_prob_list, dim=-1))
             output = self.output(net_state)
             vocab_logits = torch.nn.functional.linear(
                 input=output, weight=self.embedding_matrix.weight)
             vocab_prob = torch.softmax(vocab_logits, dim=-1)
             input_logits = torch.einsum("bd,bsd->bs",
                                         self.copy_output(net_state),
                                         input_context)
             input_logits += (1.0 - context_mask) * (-1e30)
             input_prob = torch.softmax(input_logits,
                                        dim=-1)  # (bsz, enc_seq)
             # step 4: mix the two distributions according to mode_sig
             mode_sig = self.mode_select(net_state)
             vocab_prob = vocab_prob * mode_sig
             vocab_prob = torch.scatter_add(vocab_prob,
                                            dim=1,
                                            index=input_ids,
                                            src=input_prob * (1 - mode_sig))
             dec_list.append(vocab_prob[:, None, :])
         # compute the loss
         predict = torch.cat(dec_list, dim=1)  # (bsz, dec_seq, vocab)
         predict = predict.view(size=(-1, predict.size(-1)))
         decode_target = decode_target.view(size=(-1, ))
         predict = torch.gather(predict,
                                dim=1,
                                index=decode_target[:,
                                                    None]).squeeze(dim=-1)
         init_loss = -torch.log(predict + self.epsilon)
         init_loss *= (decode_target != 0).float()
         loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0,
                                                     as_tuple=False).size(0)
         return loss[None].repeat(bsz)
     else:
         if use_beam_search:
             pass
         else:  # greedy decoding
             dec_list = []
             for i in range(self.seq_len):
                 # step 1: obtain the current context_rep via attention
                 attn_score = torch.einsum("bsd,bd->bs", input_context,
                                           net_state)
                 attn_score.mul_(self.scale)
                 attn_score += (1.0 - context_mask) * (-1e30)
                 attn_prob = torch.softmax(attn_score, dim=-1)
                 attn_vec = torch.einsum("bs,bsd->bd", attn_prob,
                                         input_context)
                 # step 2: update the state
                 if i == 0:
                     emb = self.embedding_matrix(
                         torch.full(size=(bsz, ),
                                    fill_value=args["start_token_id"],
                                    dtype=torch.int32).long().to(device))
                 else:
                     emb = self.embedding_matrix(
                         dec_list[i - 1].squeeze(dim=-1))
                 x = torch.cat([attn_vec, emb, net_state], dim=-1)
                 reset_sig = self.reset_gate(x)
                 update_sig = self.update_gate(x)
                 update_value = self.update(
                     torch.cat([attn_vec, emb, reset_sig * net_state],
                               dim=-1))
                 net_state = (
                     1 - update_sig) * net_state + update_sig * update_value
                 # step 3: compute the distribution scores
                 vocab_prob_list = []
                 # pi_k = self.pi_mos(net_state)
                 # for k in range(args["mos"]):
                 #     output = self.output[k](net_state)
                 #     vocab_logits = torch.nn.functional.linear(input=output, weight=self.embedding_matrix.weight)
                 #     vocab_prob_list.append(torch.softmax(vocab_logits, dim=-1)[..., None])
                 # vocab_prob = torch.einsum("bk,bvk->bv", pi_k, torch.cat(vocab_prob_list, dim=-1))
                 output = self.output(net_state)
                 vocab_logits = torch.nn.functional.linear(
                     input=output, weight=self.embedding_matrix.weight)
                 vocab_prob = torch.softmax(vocab_logits, dim=-1)
                 input_logits = torch.einsum("bd,bsd->bs",
                                             self.copy_output(net_state),
                                             input_context)
                 input_logits += (1.0 - context_mask) * (-1e30)
                 input_prob = torch.softmax(input_logits,
                                            dim=-1)  # (bsz, enc_seq)
                 # step 4: mix the two distributions according to mode_sig
                 mode_sig = self.mode_select(net_state)
                 vocab_prob = vocab_prob * mode_sig
                 vocab_prob = torch.scatter_add(vocab_prob,
                                                dim=1,
                                                index=input_ids,
                                                src=input_prob *
                                                (1 - mode_sig))
                 dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None])
             return torch.cat(dec_list, dim=-1)
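The loss computation at the end reduces to picking each target's probability with gather and averaging the negative log over non-pad positions. In isolation:

import torch

predict = torch.tensor([[0.1, 0.7, 0.2]])  # (N, vocab)
decode_target = torch.tensor([1])
p = torch.gather(predict, 1, decode_target[:, None]).squeeze(-1)
loss = -torch.log(p + 1e-8)                # tensor([0.3567])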
Example #14
 def _impl(x, dim, index, src):
     dim = dim.item()
     return (torch.scatter_add(x, dim, index, src), )
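The shim above only unwraps a 0-dim dim tensor before delegating, so plain torch.scatter_add semantics apply:

import torch

x = torch.zeros(2, 2)
index = torch.tensor([[0, 1]])
src = torch.ones(1, 2)
torch.scatter_add(x, 0, index, src)  # tensor([[1., 0.], [0., 1.]])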
Example #15
 def forward(self, input_ids, encoder_rep, input_mask, decode_input,
             decode_target, use_beam_search, beam_width):
     bsz = input_ids.size(0)
     if decode_input is not None:  # training mode
         input_ids = input_ids[:, None, :].repeat(1, self.seq_len, 1)
         decode_embed = self.drop(self.word_emb(decode_input))
         all_ones = decode_embed.new_ones((self.seq_len, self.seq_len),
                                          dtype=torch.uint8)
         dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None]
         pos_seq = torch.arange(self.seq_len - 1,
                                -1,
                                -1.0,
                                device=device,
                                dtype=decode_embed.dtype)
         pos_embed = self.drop(self.pos_emb(pos_seq))
         core_out = decode_embed.transpose(0, 1).contiguous()
         enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
         enc_mask_t = input_mask.transpose(0, 1).contiguous()
         for layer in self.layers:
             core_out = layer(dec_inp=core_out,
                              r=pos_embed,
                              enc_inp=enc_rep_t,
                              dec_mask=dec_attn_mask,
                              enc_mask=enc_mask_t)
         core_out = self.drop(core_out.transpose(
             0, 1).contiguous())  # (bsz, dec_len, dim)
         output = self.output(core_out)
         vocab_logits = torch.nn.functional.linear(
             input=output, weight=self.word_emb.weight)
         vocab_prob = torch.softmax(vocab_logits, dim=-1)
         input_logits = torch.einsum("bid,bjd->bij",
                                     self.copy_output(core_out),
                                     encoder_rep)  # (bsz, dec_len, enc_len)
         input_logits = input_logits + (1.0 - input_mask[:, None, :].repeat(
             1, self.seq_len, 1)) * (-1e30)
         input_prob = torch.softmax(input_logits,
                                    dim=-1)  # (bsz, dec_len, enc_len)
         mode_sig = self.mode_select(core_out)  # (bsz, dec_len, 1)
         vocab_prob = vocab_prob * mode_sig
         vocab_prob = torch.scatter_add(vocab_prob,
                                        dim=2,
                                        index=input_ids,
                                        src=(1 - mode_sig) * input_prob)
         vocab_prob = vocab_prob.view(-1, args["vocab_size"])
         decode_target = decode_target.view(-1)
         predict = torch.gather(vocab_prob,
                                dim=1,
                                index=decode_target[:,
                                                    None]).squeeze(dim=-1)
         init_loss = -torch.log(predict + self.epsilon)
         init_loss *= (decode_target != 0).float()
         loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0,
                                                     as_tuple=False).size(0)
         # reshape the loss to (bsz,) so the model parallelizes cleanly
         return loss[None].repeat(bsz)
     else:  # validation / test decoding mode ==> relatively slow
         if not use_beam_search:  # greedy search ==> validation set
             dec_list = []
             decode_ids = torch.full(size=(bsz, 1),
                                     fill_value=args["start_token_id"],
                                     dtype=torch.int32).long().to(device)
             for i in range(1, self.seq_len + 1):
                 if i > 1:
                     decode_ids = torch.cat([decode_ids, dec_list[i - 2]],
                                            dim=-1)
                 decode_embed = self.word_emb(decode_ids)
                 all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8)
                 dec_attn_mask = torch.tril(all_ones,
                                            diagonal=0)[:, :, None, None]
                 pos_seq = torch.arange(i - 1,
                                        -1,
                                        -1.0,
                                        device=device,
                                        dtype=decode_embed.dtype)
                 pos_embed = self.pos_emb(pos_seq)
                 core_out = decode_embed.transpose(0, 1).contiguous()
                 enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
                 enc_mask_t = input_mask.transpose(0, 1).contiguous()
                 for layer in self.layers:
                     core_out = layer(dec_inp=core_out,
                                      r=pos_embed,
                                      enc_inp=enc_rep_t,
                                      dec_mask=dec_attn_mask,
                                      enc_mask=enc_mask_t)
                 core_out = core_out.transpose(0, 1).contiguous()[:, -1, :]
                 output = self.output(core_out)
                 vocab_logits = torch.nn.functional.linear(
                     input=output, weight=self.word_emb.weight)
                 vocab_prob = torch.softmax(vocab_logits, dim=-1)
                 input_logits = torch.einsum("bd,bjd->bj",
                                             self.copy_output(core_out),
                                             encoder_rep)  # (bsz, enc_len)
                 input_logits = input_logits + (1.0 - input_mask) * (-1e30)
                 input_prob = torch.softmax(input_logits,
                                            dim=-1)  # (bsz, enc_len)
                 mode_sig = self.mode_select(core_out)  # (bsz, 1)
                 vocab_prob = vocab_prob * mode_sig
                 vocab_prob = torch.scatter_add(vocab_prob,
                                                dim=1,
                                                index=input_ids,
                                                src=(1 - mode_sig) *
                                                input_prob)
                 dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None])
             return torch.cat(dec_list, dim=-1)
         else:  # beam search
             # expand to beam_width * bsz
             """
             Note: 1. trigram blocking ==> add -1e9 as soon as a repeat appears (must respect the
             end_token boundary => only apply within it); 2. length penalty, also bounded by end_token.
             """
             decode_ids = torch.full(size=(bsz * beam_width, 1),
                                     fill_value=args["start_token_id"],
                                     dtype=torch.int32).long().to(device)
             input_ids = input_ids.repeat((beam_width, 1))
             encoder_rep = encoder_rep.repeat((beam_width, 1, 1))
             input_mask = input_mask.repeat((beam_width, 1))
             dec_topK_log_probs = [0] * (beam_width * bsz)  # (bsz*beam) running log-prob of each sequence
             dec_topK_sequences = [[] for _ in range(beam_width * bsz)]  # (bsz*beam, seq_len) decoded ids
             dec_topK_seq_lens = [1] * (beam_width * bsz)  # lengths, offset by 1 to avoid div 0 in the length penalty
             for i in range(1, self.seq_len + 1):
                 if i > 1:
                     input_decode_ids = torch.cat([
                         decode_ids,
                         torch.tensor(dec_topK_sequences).long().to(device)
                     ],
                                                  dim=-1)
                 else:
                     input_decode_ids = decode_ids
                 decode_embed = self.word_emb(input_decode_ids)
                 all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8)
                 dec_attn_mask = torch.tril(all_ones,
                                            diagonal=0)[:, :, None, None]
                 pos_seq = torch.arange(i - 1,
                                        -1,
                                        -1.0,
                                        device=device,
                                        dtype=decode_embed.dtype)
                 pos_embed = self.pos_emb(pos_seq)
                 core_out = decode_embed.transpose(0, 1).contiguous()
                 enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
                 enc_mask_t = input_mask.transpose(0, 1).contiguous()
                 for layer in self.layers:
                     core_out = layer(dec_inp=core_out,
                                      r=pos_embed,
                                      enc_inp=enc_rep_t,
                                      dec_mask=dec_attn_mask,
                                      enc_mask=enc_mask_t)
                 core_out = core_out.transpose(0, 1).contiguous()[:, -1, :]
                 output = self.output(core_out)
                 vocab_logits = torch.nn.functional.linear(
                     input=output, weight=self.word_emb.weight)
                 vocab_prob = torch.softmax(vocab_logits, dim=-1)
                 input_logits = torch.einsum(
                     "bd,bjd->bj", self.copy_output(core_out),
                     encoder_rep)  # (bsz*beam, enc_len)
                 input_logits = input_logits + (1.0 - input_mask) * (-1e30)
                 input_prob = torch.softmax(input_logits,
                                            dim=-1)  # (bsz*beam, enc_len)
                 mode_sig = self.mode_select(core_out)  # (bsz*beam, 1)
                 vocab_prob = vocab_prob * mode_sig
                 vocab_prob = torch.scatter_add(
                     vocab_prob,
                     dim=1,
                     index=input_ids,
                     src=(1 - mode_sig) * input_prob)  # (bsz*beam, vocab)
                 vocab_logp = torch.log(vocab_prob + self.epsilon)  # log with eps for stability
                 """ step1: 检查是否存在trigram blocking重叠, 只需要检查最后一项和之前项是否存在重叠即可 """
                 if i > 4:  # only worth checking once the decoded sequence is at least 4 tokens long
                     for j in range(beam_width * bsz):
                         trigram_blocks = []
                         for k in range(3, i):
                             if dec_topK_sequences[j][
                                     k - 1] == args["end_token_id"]:
                                 break
                             trigram_blocks.append(
                                 dec_topK_sequences[j][k - 3:k])
                         if len(trigram_blocks) > 1 and trigram_blocks[
                                 -1] in trigram_blocks[:-1]:
                             dec_topK_log_probs[j] += -1e9
                 """ step2: 为每个样本, 选择topK个序列 ==> 类似于重构dec_topK_sequences"""
                 for j in range(bsz):
                     topK_vocab_logp = vocab_logp[j::bsz]  # (k, vocab)
                     candidate_list = []
                     """ 容易出错的地方, i=1的时候不需要为每个K生成K个候选,否则beam search将完全沦为greedy search """
                     for k in range(beam_width):
                         ind = j + k * bsz
                         if args["end_token_id"] in dec_topK_sequences[ind]:
                             candidate_list.append({
                                 "add_logit":
                                 0,
                                 "add_seq_len":
                                 0,
                                 "affiliate_k":
                                 k,
                                 "add_token_id":
                                 args["end_token_id"],
                                 "sort_logits":
                                 dec_topK_log_probs[ind] /
                                 (dec_topK_seq_lens[ind]**
                                  args["beam_length_penalty"])
                             })
                         else:
                             k_logps, k_indices = topK_vocab_logp[k].topk(
                                 dim=0, k=beam_width)
                             k_logps, k_indices = k_logps.cpu().numpy(
                             ), k_indices.cpu().numpy()
                             for l in range(beam_width):
                                 aff = l if i == 1 else k
                                 candidate_list.append({
                                     "add_logit":
                                     k_logps[l],
                                     "add_seq_len":
                                     1,
                                     "affiliate_k":
                                     aff,
                                     "add_token_id":
                                     k_indices[l],
                                     "sort_logits":
                                     (dec_topK_log_probs[ind] + k_logps[l])
                                     / ((dec_topK_seq_lens[ind] + 1)**
                                        args["beam_length_penalty"])
                                 })
                         if i == 1:  ## only one beam is meaningful when decoding the first token
                             break
                     candidate_list.sort(key=lambda x: x["sort_logits"],
                                         reverse=True)
                     candidate_list = candidate_list[:beam_width]
                     """ 序列修正, 更新topK """
                     c_dec_topK_sequences, c_dec_topK_log_probs, c_dec_topK_seq_lens = \
                         deepcopy(dec_topK_sequences), deepcopy(dec_topK_log_probs), deepcopy(dec_topK_seq_lens)
                     for k in range(beam_width):
                         ind = bsz * candidate_list[k]["affiliate_k"] + j
                         r_ind = bsz * k + j
                         father_seq, father_logits, father_len = c_dec_topK_sequences[
                             ind], c_dec_topK_log_probs[
                                 ind], c_dec_topK_seq_lens[ind]
                         dec_topK_sequences[r_ind] = father_seq + [
                             candidate_list[k]["add_token_id"]
                         ]
                         dec_topK_log_probs[
                             r_ind] = father_logits + candidate_list[k][
                                 "add_logit"]
                         dec_topK_seq_lens[
                             r_ind] = father_len + candidate_list[k][
                                 "add_seq_len"]
             return torch.tensor(dec_topK_sequences[:bsz]).long().to(device)
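The trigram-blocking test inside the beam loop, extracted as plain Python: a candidate is penalized when its newest trigram already occurred earlier in the sequence.

seq = [5, 6, 7, 5, 6, 7]
trigrams = [seq[k - 3:k] for k in range(3, len(seq) + 1)]
blocked = trigrams[-1] in trigrams[:-1]  # True: [5, 6, 7] repeats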
Example #16
def map_edge_lists(edge_lists: list,
                   perform_unique=True,
                   known_node_ids=None,
                   sequential_train_nodes=False,
                   sequential_deg_nodes=0):
    print("Remapping Edges")

    defined_edges = []
    for edge_list in edge_lists:
        if edge_list is not None:
            defined_edges.append(edge_list)

    edge_lists = defined_edges

    if isinstance(edge_lists[0], pd.DataFrame):
        if isinstance(edge_lists[0].iloc[0][0], str):
            # need to take uniques using pandas for string datatypes, since torch doesn't support strings
            return map_edge_list_dfs(edge_lists, known_node_ids,
                                     sequential_train_nodes,
                                     sequential_deg_nodes)

        new_edge_lists = []
        for edge_list in edge_lists:
            new_edge_lists.append(dataframe_to_tensor(edge_list))

        edge_lists = new_edge_lists

    all_edges = torch.cat(edge_lists)

    has_rels = False
    num_rels = 1
    unique_rels = torch.empty([0])
    mapped_rel_ids = torch.empty([0])
    if all_edges.size(1) == 3:
        has_rels = True

    output_dtype = torch.int32

    if perform_unique:
        unique_src = torch.unique(all_edges[:, 0])
        unique_dst = torch.unique(all_edges[:, -1])
        if known_node_ids is None:
            unique_nodes = torch.unique(torch.cat([unique_src, unique_dst]),
                                        sorted=True)
        else:
            unique_nodes = torch.unique(torch.cat([unique_src, unique_dst] +
                                                  known_node_ids),
                                        sorted=True)

        num_nodes = unique_nodes.size(0)

        if has_rels:
            unique_rels = torch.unique(all_edges[:, 1], sorted=True)
            num_rels = unique_rels.size(0)
    else:
        # torch.max on a 1-D tensor returns a 0-dim tensor; [0] would fail, so take the value
        num_nodes = torch.max(all_edges[:, 0]).item()
        unique_nodes = torch.arange(num_nodes).to(output_dtype)

        if has_rels:
            num_rels = torch.max(all_edges[:, 1]).item()
            unique_rels = torch.arange(num_rels).to(output_dtype)

    if sequential_train_nodes or sequential_deg_nodes > 0:
        seq_nodes = None

        if sequential_train_nodes and sequential_deg_nodes <= 0:
            print("Sequential Train Nodes")
            seq_nodes = known_node_ids[0]
        else:
            out_degrees = torch.zeros([
                num_nodes,
            ], dtype=torch.int32)
            out_degrees = torch.scatter_add(
                out_degrees, 0,
                torch.squeeze(edge_lists[0][:, 0]).to(torch.int64),
                torch.ones([
                    edge_lists[0].shape[0],
                ], dtype=torch.int32))

            in_degrees = torch.zeros([
                num_nodes,
            ], dtype=torch.int32)
            in_degrees = torch.scatter_add(
                in_degrees, 0,
                torch.squeeze(edge_lists[0][:, -1]).to(torch.int64),
                torch.ones([
                    edge_lists[0].shape[0],
                ], dtype=torch.int32))

            degrees = in_degrees + out_degrees

            deg_argsort = torch.argsort(degrees, dim=0, descending=True)
            high_degree_nodes = deg_argsort[:sequential_deg_nodes]

            print("High Deg Nodes Degree Sum:",
                  torch.sum(degrees[high_degree_nodes]).numpy())

            if sequential_train_nodes and sequential_deg_nodes > 0:
                print("Sequential Train and High Deg Nodes")
                seq_nodes = torch.unique(
                    torch.cat([high_degree_nodes, known_node_ids[0]]))
                seq_nodes = seq_nodes.index_select(
                    0, torch.randperm(seq_nodes.size(0), dtype=torch.int64))
                print("Total Seq Nodes: ", seq_nodes.shape[0])
            else:
                print("Sequential High Deg Nodes")
                seq_nodes = high_degree_nodes

        seq_mask = torch.zeros(num_nodes, dtype=torch.bool)
        seq_mask[seq_nodes.to(torch.int64)] = True
        all_other_nodes = torch.arange(num_nodes, dtype=seq_nodes.dtype)
        all_other_nodes = all_other_nodes[~seq_mask]

        mapped_node_ids = -1 * torch.ones(num_nodes, dtype=output_dtype)
        mapped_node_ids[seq_nodes.to(torch.int64)] = torch.arange(
            seq_nodes.shape[0], dtype=output_dtype)
        mapped_node_ids[all_other_nodes.to(
            torch.int64)] = seq_nodes.shape[0] + torch.randperm(
                num_nodes - seq_nodes.shape[0], dtype=output_dtype)
    else:
        mapped_node_ids = torch.randperm(num_nodes, dtype=output_dtype)

    if has_rels:
        mapped_rel_ids = torch.randperm(num_rels, dtype=output_dtype)

    # TODO may use too much memory if the max id is very large
    # Needed to support indexing w/ the remap
    if torch.max(unique_nodes) + 1 > num_nodes:
        extended_map = torch.zeros(torch.max(unique_nodes) + 1,
                                   dtype=output_dtype)
        extended_map[unique_nodes] = mapped_node_ids
    else:
        extended_map = mapped_node_ids

    all_edges = None  # can safely free this tensor

    output_edge_lists = []
    for edge_list in edge_lists:
        new_src = extended_map[edge_list[:, 0].to(torch.int64)]
        new_dst = extended_map[edge_list[:, -1].to(torch.int64)]

        if has_rels:
            new_rel = mapped_rel_ids[edge_list[:, 1].to(torch.int64)]
            output_edge_lists.append(
                torch.stack([new_src, new_rel, new_dst], dim=1))
        else:
            output_edge_lists.append(torch.stack([new_src, new_dst], dim=1))

    node_mapping = np.stack(
        [unique_nodes.numpy(), mapped_node_ids.numpy()], axis=1)
    rel_mapping = None
    if has_rels:
        rel_mapping = np.stack(
            [unique_rels.numpy(), mapped_rel_ids.numpy()], axis=1)

    return output_edge_lists, node_mapping, rel_mapping
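The degree computation above, shrunk to a toy edge list without a relation column:

import torch

edge_list = torch.tensor([[0, 1], [0, 2], [2, 1]])
num_nodes = 3
ones = torch.ones(edge_list.size(0), dtype=torch.int32)
out_degrees = torch.scatter_add(
    torch.zeros(num_nodes, dtype=torch.int32), 0, edge_list[:, 0], ones)
in_degrees = torch.scatter_add(
    torch.zeros(num_nodes, dtype=torch.int32), 0, edge_list[:, 1], ones)
# out_degrees == tensor([2, 0, 1]), in_degrees == tensor([0, 2, 1])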
Example #17
 def forward(self, input_ids, encoder_rep, input_mask, decode_input,
             decode_target, use_beam_search, beam_width):
     bsz = input_ids.size(0)
     if decode_input is not None:  # training mode
         input_ids = input_ids[:, None, :].repeat(1, self.seq_len, 1)
         decode_embed = self.drop(self.word_emb(decode_input))
         all_ones = decode_embed.new_ones((self.seq_len, self.seq_len),
                                          dtype=torch.uint8)
         dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None, None]
         pos_seq = torch.arange(self.seq_len - 1,
                                -1,
                                -1.0,
                                device=device,
                                dtype=decode_embed.dtype)
         pos_embed = self.drop(self.pos_emb(pos_seq))
         core_out = decode_embed.transpose(0, 1).contiguous()
         enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
         enc_mask_t = input_mask.transpose(0, 1).contiguous()
         for layer in self.layers:
             core_out = layer(dec_inp=core_out,
                              r=pos_embed,
                              enc_inp=enc_rep_t,
                              dec_mask=dec_attn_mask,
                              enc_mask=enc_mask_t)
         core_out = self.drop(core_out.transpose(
             0, 1).contiguous())  # (bsz, dec_len, dim)
         output = self.output(core_out)
         vocab_logits = torch.nn.functional.linear(
             input=output, weight=self.word_emb.weight)
         vocab_prob = torch.softmax(vocab_logits, dim=-1)
         input_logits = torch.einsum("bid,bjd->bij",
                                     self.copy_output(core_out),
                                     encoder_rep)  # (bsz, dec_len, enc_len)
         input_logits = input_logits + (1.0 - input_mask[:, None, :].repeat(
             1, self.seq_len, 1)) * (-1e30)
         input_prob = torch.softmax(input_logits,
                                    dim=-1)  # (bsz, dec_len, enc_len)
         mode_sig = self.mode_select(core_out)  # (bsz, dec_len, 1)
         vocab_prob = vocab_prob * mode_sig
         vocab_prob = torch.scatter_add(vocab_prob,
                                        dim=2,
                                        index=input_ids,
                                        src=(1 - mode_sig) * input_prob)
         vocab_prob = vocab_prob.view(-1, args["vocab_size"])
         decode_target = decode_target.view(-1)
         predict = torch.gather(vocab_prob,
                                dim=1,
                                index=decode_target[:,
                                                    None]).squeeze(dim=-1)
         init_loss = -torch.log(predict + self.epsilon)
         init_loss *= (decode_target != 0).float()
         loss = torch.sum(init_loss) / torch.nonzero(decode_target != 0,
                                                     as_tuple=False).size(0)
         # reshape the loss to (bsz,) so the model parallelizes cleanly
         return loss[None].repeat(bsz)
     else:  # validation / test decoding mode ==> relatively slow
         dec_list = []
         decode_ids = torch.full(size=(bsz, 1),
                                 fill_value=args["start_token_id"],
                                 dtype=torch.int32).long().to(device)
         for i in range(1, self.seq_len + 1):
             if i > 1:
                 decode_ids = torch.cat([decode_ids, dec_list[i - 2]],
                                        dim=-1)
             decode_embed = self.word_emb(decode_ids)
             all_ones = decode_embed.new_ones((i, i), dtype=torch.uint8)
             dec_attn_mask = torch.tril(all_ones, diagonal=0)[:, :, None,
                                                              None]
             pos_seq = torch.arange(i - 1,
                                    -1,
                                    -1.0,
                                    device=device,
                                    dtype=decode_embed.dtype)
             pos_embed = self.pos_emb(pos_seq)
             core_out = decode_embed.transpose(0, 1).contiguous()
             enc_rep_t = encoder_rep.transpose(0, 1).contiguous()
             enc_mask_t = input_mask.transpose(0, 1).contiguous()
             for layer in self.layers:
                 core_out = layer(dec_inp=core_out,
                                  r=pos_embed,
                                  enc_inp=enc_rep_t,
                                  dec_mask=dec_attn_mask,
                                  enc_mask=enc_mask_t)
             core_out = core_out.transpose(0, 1).contiguous()[:, -1, :]
             output = self.output(core_out)
             vocab_logits = torch.nn.functional.linear(
                 input=output, weight=self.word_emb.weight)
             vocab_prob = torch.softmax(vocab_logits, dim=-1)
             input_logits = torch.einsum("bd,bjd->bj",
                                         self.copy_output(core_out),
                                         encoder_rep)  # (bsz, enc_len)
             input_logits = input_logits + (1.0 - input_mask) * (-1e30)
             input_prob = torch.softmax(input_logits,
                                        dim=-1)  # (bsz, enc_len)
             mode_sig = self.mode_select(core_out)  # (bsz, 1)
             vocab_prob = vocab_prob * mode_sig
             vocab_prob = torch.scatter_add(vocab_prob,
                                            dim=1,
                                            index=input_ids,
                                            src=(1 - mode_sig) * input_prob)
             dec_list.append(torch.argmax(vocab_prob, dim=-1)[:, None])
         return torch.cat(dec_list, dim=-1)
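One idiom these decoders lean on, shown alone: padded positions receive (1 - mask) * (-1e30) before softmax, which drives their probability to zero.

import torch

input_logits = torch.tensor([[1.0, 2.0, 3.0]])
input_mask = torch.tensor([[1.0, 1.0, 0.0]])  # 1 = real token, 0 = pad
probs = torch.softmax(input_logits + (1.0 - input_mask) * (-1e30), dim=-1)
# probs == tensor([[0.2689, 0.7311, 0.0000]])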
Example #18
def model(x, y, alt_av, alt_ids_cuda):
    # global parameters in the model
    if diagonal_alpha:
        alpha_mu = pyro.sample(
            "alpha",
            dist.Normal(torch.zeros(len(non_mix_params), device=x.device),
                        1).to_event(1))
    else:
        alpha_mu = pyro.sample(
            "alpha",
            dist.MultivariateNormal(
                torch.zeros(len(non_mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(non_mix_params), device=x.device))))

    if diagonal_beta_mu:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.Normal(torch.zeros(len(mix_params), device=x.device),
                        1.).to_event(1))
    else:
        beta_mu = pyro.sample(
            "beta_mu",
            dist.MultivariateNormal(
                torch.zeros(len(mix_params), device=x.device),
                scale_tril=torch.tril(
                    1 * torch.eye(len(mix_params), device=x.device))))

    # Vector of variances for each of the d variables
    theta = pyro.sample(
        "theta",
        dist.HalfCauchy(
            10. * torch.ones(len(mix_params), device=x.device)).to_event(1))
    # Lower cholesky factor of a correlation matrix
    eta = 1. * torch.ones(
        1, device=x.device
    )  # Implies a uniform distribution over correlation matrices
    L_omega = pyro.sample("L_omega",
                          dist.LKJCorrCholesky(len(mix_params), eta))
    # Lower cholesky factor of the covariance matrix
    L_Omega = torch.mm(torch.diag(theta.sqrt()), L_omega)

    # local parameters in the model
    random_params = pyro.sample(
        "beta_resp",
        dist.MultivariateNormal(beta_mu.repeat(num_resp, 1),
                                scale_tril=L_Omega).to_event(1))

    # vector of respondent parameters: global + local (respondent)
    params_resp = torch.cat([alpha_mu.repeat(num_resp, 1), random_params],
                            dim=-1)

    # vector of betas of MXL (may repeat the same learnable parameter multiple times; random + fixed effects)
    beta_resp = torch.cat([
        params_resp[:, beta_to_params_map[i]] for i in range(num_alternatives)
    ],
                          dim=-1)

    with pyro.plate("locals", len(x), subsample_size=BATCH_SIZE) as ind:

        with pyro.plate("data_resp", T):
            # compute utilities for each alternative
            utilities = torch.scatter_add(
                zeros_vec[:, ind, :], 2,
                alt_ids_cuda[ind, :, :].transpose(0, 1),
                torch.mul(x[ind, :, :].transpose(0, 1), beta_resp[ind, :]))

            # adjust utility for unavailable alternatives
            utilities += alt_av[ind, :, :].transpose(0, 1)

            # likelihood
            pyro.sample("obs",
                        dist.Categorical(logits=utilities),
                        obs=y[ind, :].transpose(0, 1))
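The utility accumulation on toy shapes: each attribute's contribution (beta * x) is summed into the slot of the alternative that owns it, along dim 2.

import torch

zeros_vec = torch.zeros(1, 1, 2)             # (T, batch, num_alternatives)
alt_ids = torch.tensor([[[0, 0, 1]]])        # owning alternative per attribute
contrib = torch.tensor([[[0.5, 1.0, 2.0]]])  # beta * x per attribute
utilities = torch.scatter_add(zeros_vec, 2, alt_ids, contrib)
# utilities == tensor([[[1.5, 2.0]]])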
Example #19
    def backward(ctx, grad_output):
        featureH, featureL, label, centers, margins, gamma = ctx.saved_tensors
        num_pair = featureH.shape[0]
        num_class = centers.shape[0]
        centers_normed = F.normalize(centers, dim=1)  # Cx1024
        featureH_normed = F.normalize(featureH, dim=1)  # N'x1024
        distmatH = torch.mm(featureH_normed, centers_normed.t())  # N'xC
        featureL_normed = F.normalize(featureL, dim=1)
        distmatL = torch.mm(featureL_normed, centers_normed.t())
        mask = distmatH.new_zeros(distmatH.size(), dtype=torch.long)
        mask.scatter_add_(
            1,
            label.long().unsqueeze(1),
            torch.ones(num_pair, 1, device=mask.device, dtype=torch.long))

        distHic = distmatH[mask == 1]
        distLic = distmatL[mask == 1]
        distHicL, hard_index_batch = torch.max(distmatH[mask == 0].view(
            num_pair, num_class - 1),
                                               dim=1)
        hard_index_batch[hard_index_batch >= label] += 1
        li1 = torch.acos(distHic) - torch.acos(distLic) + margins[0]
        li2 = torch.acos(distHic) - torch.acos(distHicL) + margins[1]

        centers_normed_batch = centers_normed.index_select(0, label.long())
        hard_normed_batch = centers_normed.index_select(0, hard_index_batch)

        d = -(1 - distHic.pow(2)).pow(-0.5)
        e = -(1 - distHicL.pow(2)).pow(-0.5)
        f = -(1 - distLic.pow(2)).pow(-0.5)
        I = torch.eye(featureH.shape[1], device=d.device)
        xcH = (I - torch.einsum('bi,bj->bij',
                                (featureH_normed, featureH_normed))
               ) / featureH.norm(dim=1, keepdim=True).unsqueeze(-1)
        xcL = (I - torch.einsum('bi,bj->bij',
                                (featureL_normed, featureL_normed))
               ) / featureL.norm(dim=1, keepdim=True).unsqueeze(-1)
        cc = (I - torch.einsum('bi,bj->bij', (centers_normed, centers_normed))
              ) / centers.norm(dim=1, keepdim=True).unsqueeze(-1)
        d = d.unsqueeze(1)
        e = e.unsqueeze(1)
        f = f.unsqueeze(1)

        counts_h = centers.new_ones(num_class)  # (C,)
        counts_hl = centers.new_ones(num_class)  # (C,)
        counts_c = centers.new_ones(num_class)  # (C,)
        ones_h = centers.new_ones(num_pair)  # (N',)
        ones_h[li2 <= 0] = 0
        ones_c = centers.new_ones(num_pair)  # (N',)
        ones_c[li1 <= 0] = 0
        grad_centers = centers.new_zeros(centers.size())  # Cx1024

        counts_h.scatter_add_(0, label.long(), ones_h)
        counts_hl.scatter_add_(0, hard_index_batch, ones_h)
        counts_c.scatter_add_(0, label.long(), ones_c)

        grad_centers_h = featureH_normed * d
        grad_centers_h[li2 <= 0] = 0
        grad_centers += torch.scatter_add(
            centers.new_zeros(centers.size()), 0,
            label.unsqueeze(1).expand(featureH_normed.size()).long(),
            grad_centers_h) / counts_h.unsqueeze(-1)

        grad_centers_hl = -featureH_normed * e
        grad_centers_hl[li2 <= 0] = 0
        grad_centers += torch.scatter_add(
            centers.new_zeros(centers.size()), 0,
            hard_index_batch.unsqueeze(1).expand(featureH_normed.size()),
            grad_centers_hl) / counts_hl.unsqueeze(-1)

        grad_centers_c = featureH_normed * d - featureL_normed * f
        grad_centers_c[li1 <= 0] = 0
        grad_centers += torch.scatter_add(
            centers.new_zeros(centers.size()), 0,
            label.unsqueeze(1).expand(featureH_normed.size()).long(),
            grad_centers_c * gamma) / counts_c.unsqueeze(-1)

        grad_centers /= num_pair

        grad = centers_normed_batch * d - hard_normed_batch * e
        grad[li2 <= 0] = 0

        grad_h = centers_normed_batch * d
        grad_h[li1 <= 0] = 0

        grad_h = grad_output * (grad + grad_h * gamma) / num_pair

        grad_l = -centers_normed_batch * f
        grad_l[li1 <= 0] = 0
        grad_l = grad_output * grad_l * gamma / num_pair

        grad_featureH = torch.bmm(xcH, grad_h.unsqueeze(-1)).squeeze(-1)
        grad_featureL = torch.bmm(xcL, grad_l.unsqueeze(-1)).squeeze(-1)
        grad_centers = torch.bmm(cc, grad_centers.unsqueeze(-1)).squeeze(-1)
        # gradients for (featureH, featureL, label, centers, margins, gamma)
        return grad_featureH, grad_featureL, None, grad_centers, None, None
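The backward pass above leans on two scatter_add idioms: a one-hot class mask (ones scattered along dim 1) and per-class counts of active pairs (ones scattered along dim 0, later used to average the center gradients). A minimal, self-contained sketch of both, with illustrative sizes and labels:

import torch

num_pair, num_class = 4, 3
label = torch.tensor([0, 2, 1, 2])

# one-hot mask, as in mask.scatter_add_(1, label.unsqueeze(1), ones)
mask = torch.zeros(num_pair, num_class, dtype=torch.long)
mask.scatter_add_(1, label.unsqueeze(1),
                  torch.ones(num_pair, 1, dtype=torch.long))
# mask[i, label[i]] == 1, all other entries 0

# per-class counts, as in counts_h.scatter_add_(0, label.long(), ones_h);
# starting from ones avoids division by zero for empty classes
counts = torch.ones(num_class)
counts.scatter_add_(0, label, torch.ones(num_pair))
# counts == tensor([2., 2., 3.])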
Example #20
0
    def forward(self, src_context_output, src_word_output, src_word_len, 
                KG_word_output, KG_word_len, KG_word_seq, tgt_word_input, combine_knowledge = False):
        """Compute decoder scores from context_output.
        Args:
            src_context_output (FloatTensor) : (batch_size, context_size)
            src_word_output (FloatTensor)    : (batch_size, src_max_word_len, word_size)
            src_word_len (LongTensor)        : (batch_size)
            KG_word_output (FloatTensor)     : (batch_size, KG_max_word_len, KG_word_size)
            KG_word_len (LongTensor)         : (batch_size)
            KG_word_seq (LongTensor)         : (batch_size, KG_max_word_len)
            tgt_word_input (LongTensor)      : (batch_size, tgt_word_len)
        Returns:
            logit (FloatTensor)              : (batch_size, tgt_word_len, num_vocab)
            coverage (FloatTensor)           : (batch_size, tgt_word_len)
        """
        assert src_context_output.size(0) == src_word_output.size(0)
        assert src_context_output.size(0) == tgt_word_input.size(0)
        assert src_context_output.size(0) == KG_word_output.size(0)  
        batch_size = src_context_output.size(0)
        max_src_len = src_word_output.size(1)
        max_KG_len = KG_word_output.size(1)
        max_tgt_len = tgt_word_input.size(1)

        # prepare the source word and KG word outputs
        # src_word_output : [src_word_len, batch_size, src_word_size]
        src_word_output = src_word_output.permute(1, 0, 2)
        # src_word_mask : [max_src_len, batch_size]
        src_word_mask = generate_mask_by_length(src_word_len, max_src_len) 
        # KG_word_output : [KG_word_len, batch_size, KG_word_size]
        KG_word_output = KG_word_output.permute(1, 0, 2)
        # KG_word_mask : [max_KG_len, batch_size]
        KG_word_mask = generate_mask_by_length(KG_word_len, max_KG_len) 

        # obtain word embedding and initial hidden states
        # tgt_word_emb : [batch_size, tgt_word_len, emb_size]
        tgt_word_emb = self.word_embedding(tgt_word_input)
        # hidden : [batch_size, num_layer, rnn_size]
        hidden = self.context2hidden(src_context_output).view(batch_size, self.num_layers, self.rnn_size)
        # hidden : [num_layer, batch_size, rnn_size]
        hidden = hidden.permute(1, 0, 2)
         
        logit_word_list = []
        coverage_list = []
        for word_index in range(max_tgt_len):
            # recurrence
            # last_hidden : [batch_size, rnn_size]
            last_hidden = hidden[-1]
            # attn_src_word : [batch_size, src_word_size] 
            attn_src_word,_ = self.attention(last_hidden, src_word_output, src_word_output, src_word_mask)
            # attn_KG_word : [batch_size, src_KG_size] 
            # attn_KG_scores : [max_KG_len ,batch_size]
            KG_attention_query = last_hidden + attn_src_word
            attn_KG_word, attn_KG_scores = self.attention(KG_attention_query, KG_word_output, KG_word_output, KG_word_mask)
            # rnn_inputs : [batch_size, emb_size + src_word_size + KG_word_size]
            rnn_inputs = torch.cat([tgt_word_emb[:,word_index], attn_src_word, attn_KG_word], dim = 1)
            # rnn_output : [batch_size, rnn_size] ; hidden : [num_layer, batch_size, rnn_size]
            rnn_output, hidden = self.rnn_cell(rnn_inputs, hidden)
            # prob_word : [batch_size, num_vocab]
            prob_word = self.output(rnn_output)
            if combine_knowledge:
                # copy_prob : [batch_size, num_vocab]
                copy_prob = torch.zeros(batch_size, self.num_vocab, device = src_context_output.device)
                copy_prob = torch.scatter_add(input = copy_prob, 
                                              dim = 1, 
                                              index = KG_word_seq, src = attn_KG_scores.permute(1, 0))
                # gen_dist_trans_input : [batch_size, emb_size + KG_rnn_size + rnn_size]
                gen_dist_trans_input = torch.cat([tgt_word_emb[:, word_index], attn_KG_word, rnn_output], dim = 1)
                gen_dist = self.gen_dist_trans(gen_dist_trans_input)
                gen_dist = torch.sigmoid(gen_dist)
                # combined_prob_word: [batch_size, num_vocab]
                combined_prob_word = prob_word * gen_dist + (1 - gen_dist) * copy_prob
            else:
                combined_prob_word = prob_word
            # logit_word : [batch_size, num_vocab]
            logit_word = combined_prob_word.log()
            # coverage_score : [batch_size]
            coverage_score = cos_sim(last_hidden, hidden[-1])
            coverage_score = F.relu(coverage_score)

            logit_word_list.append(logit_word)
            coverage_list.append(coverage_score)
        # logit : [batch_size, max_tgt_len, num_vocab]
        logit = torch.stack(logit_word_list, dim = 1)
        # coverage : [batch_size, max_tgt_len]
        coverage = torch.stack(coverage_list, dim = 1)

        return logit, coverage
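The combine_knowledge branch above is a pointer/copy step: attention weights over KG positions are scattered into vocabulary space, so weights belonging to the same token id accumulate. A minimal sketch with illustrative values (scores already batch-first, unlike the permuted attn_KG_scores above):

import torch

batch_size, num_vocab = 2, 6
KG_word_seq = torch.tensor([[4, 1, 4],          # vocab ids of the KG words
                            [0, 5, 2]])         # [batch_size, KG_len]
attn_scores = torch.tensor([[0.2, 0.5, 0.3],
                            [0.1, 0.6, 0.3]])   # [batch_size, KG_len]

copy_prob = torch.scatter_add(torch.zeros(batch_size, num_vocab),
                              dim=1, index=KG_word_seq, src=attn_scores)
# id 4 occurs twice in row 0, so copy_prob[0, 4] == 0.2 + 0.3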
Example #21
0
def chamfer_loss(xyz1, xyz2, batch1=None, batch2=None, reduce='mean'):
    """
    Calculates the Chamfer distance between two batches of point clouds.

    :param xyz1:
        a point cloud of shape ``(b, n1, k)`` or ``(bn1, k)``.
    :param xyz2:
        a point cloud of shape ``(b, n2, k)`` or ``(bn2, k)``.
    :param batch1:
        batch indicator of shape ``(bn1,)`` for each point in `xyz1`.
        If ``None``, all points are assumed to belong to the same point cloud.
        For batched (3D) point clouds, this must be ``None``.
    :param batch2:
        batch indicator of shape ``(bn2,)`` for each point in `xyz2`.
        If ``None``, all points are assumed to belong to the same point cloud.
        For batched (3D) point clouds, this must be ``None``.
    :param reduce:
        ``'mean'`` or ``'sum'``. Default: ``'mean'``.
    :return:
        the Chamfer distance between the inputs.
    """
    assert xyz1.ndim == xyz2.ndim, 'two point clouds do not have the same number of dimensions'
    assert len(xyz1.shape) in (2, 3) and len(
        xyz2.shape) in (2, 3), 'unknown shape of tensors'
    assert xyz1.shape[-1] == xyz2.shape[-1], 'mismatched feature dimension'
    assert reduce in ('mean', 'sum'), 'Unknown reduce method'

    if xyz1.dim() == 3:
        assert batch1 is None and batch2 is None, 'batch indicators must be None when point clouds are 3D tensors'
        assert xyz1.shape[0] == xyz2.shape[0], 'mismatched batch dimension'

        batch_idx = T.arange(xyz1.shape[0], device=xyz1.device)
        batch_idx = batch_idx[:, None]
        batch1 = T.zeros(*xyz1.shape[:-1], device=xyz1.device, dtype=T.long)
        batch1 += batch_idx
        batch1 = batch1.flatten()
        batch2 = T.zeros(*xyz2.shape[:-1], device=xyz2.device, dtype=T.long)
        batch2 += batch_idx
        batch2 = batch2.flatten()

        xyz1 = xyz1.reshape(-1, xyz1.shape[-1])
        xyz2 = xyz2.reshape(-1, xyz2.shape[-1])

    dist1, dist2, deg1, deg2 = stack_chamfer_distance(xyz1, xyz2, batch1,
                                                      batch2)
    if batch1 is None:
        loss_1 = T.mean(dist1)
        loss_2 = T.mean(dist2)
    else:
        loss_1 = T.zeros_like(deg1).to(dist1.dtype)
        loss_1 = T.scatter_add(loss_1, 0, batch1, dist1)
        loss_1 = loss_1 / deg1

        loss_2 = T.zeros_like(deg2).to(dist2.dtype)
        loss_2 = T.scatter_add(loss_2, 0, batch2, dist2)
        loss_2 = loss_2 / deg2

        reduce_fn = T.sum if reduce == 'sum' else T.mean
        loss_1 = reduce_fn(loss_1)
        loss_2 = reduce_fn(loss_2)

    return loss_1 + loss_2
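When batch indicators are given, chamfer_loss sums the per-point distances of each cloud into its own slot and divides by the cloud size before the final reduction. A minimal sketch of that scatter-and-divide step, with illustrative values:

import torch as T

dist = T.tensor([1.0, 2.0, 3.0, 4.0])   # distance of each point to its nearest neighbour
batch = T.tensor([0, 0, 1, 1])          # which cloud each point belongs to
deg = T.tensor([2.0, 2.0])              # number of points per cloud

per_cloud = T.scatter_add(T.zeros(2), 0, batch, dist) / deg
# per_cloud == tensor([1.5, 3.5]); finally reduced with T.mean or T.sum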
Example #22
0
    def forward(self, src_graph_vecs, graph_batch, graph_tensors, init_atoms, orders):
        batch_size = len(orders)

        hgraph = HTuple(
            node=src_graph_vecs.new_zeros(graph_tensors[0].size(0), self.hidden_size),
            mess=self.rnn_cell.get_init_state(graph_tensors[1]),
            vmask=self.itensor.new_zeros(graph_tensors[0].size(0)),
            emask=self.itensor.new_zeros(graph_tensors[1].size(0))
        )
        # We assume that no edge directly connects two initial subgraphs

        all_topo_preds, all_atom_preds, all_bond_preds = [], [], []
        graph_tensors = self.mpn.embed_graph(graph_tensors) + (graph_tensors[-1],)  # preprocess graph tensors

        visited = set()
        new_atoms = [a for alist in init_atoms for a in alist]
        self.update_graph_mask(graph_batch, new_atoms, hgraph, visited)

        maxt = max([len(x) for x in orders])
        for t in range(maxt):
            batch_list = [i for i in range(batch_size) if t < len(orders[i])]
            assert hgraph.vmask[0].item() == 0 and hgraph.emask[0].item() == 0

            cur_graph_tensors = self.apply_graph_mask(graph_tensors, hgraph)
            vmask = hgraph.vmask.unsqueeze(-1).float()
            hgraph.node, hgraph.mess = self.mpn.encoder(*cur_graph_tensors[:-1], mask=vmask)

            new_atoms = []
            for i in batch_list:
                xid, yid, front, fbond = orders[i][t]
                st, le = graph_tensors[-1][i]
                stop = 1 if yid is None else 0

                gvec = hgraph.node[st: st + le].sum(dim=0)
                cxt_vec = torch.cat((gvec, hgraph.node[xid]), dim=-1)
                all_topo_preds.append((cxt_vec, i, stop))

                if stop == 0:
                    new_atoms.append(yid)
                    ylabel = graph_batch.nodes[yid]['label']
                    atom_type = self.avocab[ylabel]
                    all_atom_preds.append((cxt_vec, i, atom_type))

                    hist = torch.zeros_like(hgraph.node[xid])  # avoid inplace operation
                    atom_vec = self.E_a[atom_type]
                    assert front[0] == xid
                    for zid, bt in zip(front, fbond):
                        cur_hnode = torch.cat([hist, atom_vec], dim=-1)
                        cur_hnode = self.R_bond(cur_hnode)
                        pairs = torch.cat([gvec, cur_hnode, hgraph.node[zid]], dim=-1)
                        all_bond_preds.append((pairs, i, bt))
                        if bt > 0:
                            bond_vec = self.E_b[bt]
                            z_hnode = torch.cat([hgraph.node[zid], bond_vec], dim=-1)
                            hist += self.W_bond(z_hnode)

            self.update_graph_mask(graph_batch, new_atoms, hgraph, visited)

        topo_vecs, batch_idx, topo_labels = zip_tensors(all_topo_preds)
        topo_scores = self.get_topo_score(src_graph_vecs, batch_idx, topo_vecs)
        topo_loss = self.topo_loss(topo_scores, topo_labels.float())
        topo_acc = get_accuracy_bin(topo_scores, topo_labels)
        topo_loss_batch = torch.zeros(batch_size, device=topo_loss.device)
        # accumulate per-prediction losses into their source sample, i.e.
        # topo_loss_batch[bid] += topo_loss[idx] for each (idx, bid) in batch_idx
        topo_loss_batch = torch.scatter_add(topo_loss_batch, dim=0, index=batch_idx, src=topo_loss)

        atom_vecs, batch_idx, atom_labels = zip_tensors(all_atom_preds)
        atom_scores = self.get_atom_score(src_graph_vecs, batch_idx, atom_vecs)
        atom_loss = self.atom_loss(atom_scores, atom_labels)
        atom_acc = get_accuracy(atom_scores, atom_labels)
        atom_loss_batch = torch.zeros(batch_size, device=atom_loss.device)
        atom_loss_batch = torch.scatter_add(atom_loss_batch, dim=0, index=batch_idx, src=atom_loss)

        bond_vecs, batch_idx, bond_labels = zip_tensors(all_bond_preds)
        bond_scores = self.get_bond_score(src_graph_vecs, batch_idx, bond_vecs)
        bond_loss = self.bond_loss(bond_scores, bond_labels)
        bond_acc = get_accuracy(bond_scores, bond_labels)
        bond_loss_batch = torch.zeros(batch_size, device=bond_loss.device)
        bond_loss_batch = torch.scatter_add(bond_loss_batch, dim=0, index=batch_idx, src=bond_loss)
        loss = topo_loss_batch + atom_loss_batch + bond_loss_batch
        return loss, atom_acc, topo_acc, bond_acc
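A minimal sketch, with illustrative values, of the equivalence the comments in Example #22 rely on: scatter_add over batch indices matches the explicit per-prediction accumulation loop.

import torch

batch_size = 3
batch_idx = torch.tensor([0, 2, 2, 1])        # source sample of each prediction
loss = torch.tensor([0.5, 1.0, 2.0, 0.25])    # per-prediction losses

loss_batch = torch.scatter_add(torch.zeros(batch_size),
                               dim=0, index=batch_idx, src=loss)

loss_loop = torch.zeros(batch_size)
for idx, bid in enumerate(batch_idx):
    loss_loop[bid] += loss[idx]

assert torch.allclose(loss_batch, loss_loop)  # both equal [0.5, 0.25, 3.0]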