Example #1
def VarGRUCell(input,
    # NOTE(review): the parameter list is cut off after `input,` in this excerpt.
    # The body also reads `hidden`, `w_ih`, `w_hh`, `b_ih`, `b_hh`, `noise_in`
    # and `noise_hidden`; presumably they belong to the missing part of the
    # signature -- restore it from the original before using this function.
    #
    # One GRU step whose three gate projections are computed in a single
    # batched call. When a noise mask is given it is multiplied into the
    # unsqueezed tensor; otherwise the tensor is expanded to 3 gate copies.
    input = input.expand(
        3, *
        input.size()) if noise_in is None else input.unsqueeze(0) * noise_in
    hx = hidden.expand(3, *hidden.size(
    )) if noise_hidden is None else hidden.unsqueeze(0) * noise_hidden

    # bias + (tensor @ weight) for all gates at once.
    # NOTE(review): the biases are dereferenced unconditionally, so None is
    # not supported here.
    gi = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    # Unpack along the leading (gate) axis.
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    # Interpolate between the candidate and the previous (un-noised) hidden.
    hy = newgate + inputgate * (hidden - newgate)

    return hy
Example #2
    def forward(self, hidden, encoder_outputs):
        # Scores `encoder_outputs` against `hidden` using `self.method`
        # ('dot', 'general' or 'concat') and returns the softmax of the scores
        # with an extra axis inserted at position 1.
        lengths = None
        if type(encoder_outputs) is PackedSequence:
            encoder_outputs, lengths = pad(encoder_outputs, batch_first=True)
            # NOTE(review): as written this overwrites the lengths returned by
            # `pad`, and a non-packed input leaves `lengths` as None so
            # `max(lengths)` below would raise. An `else:` was probably lost
            # from this excerpt -- confirm against the original.
            lengths = [len(x) for x in encoder_outputs]

        batch_size = encoder_outputs.size()[0]
        # Default scores; returned (after softmax) when no method matches.
        attns = cuda(T.zeros(batch_size, max(lengths)), gpu_id=self.gpu_id)
        # `lengths` is rebound to a zero tensor that only serves as the
        # additive term of the baddbmm calls below.
        lengths = cuda(T.zeros(max(lengths), 1), gpu_id=self.gpu_id)

        if self.method == 'dot':
            attns = T.baddbmm(lengths, encoder_outputs,
                              hidden.transpose(2, 1)).squeeze(2)

        elif self.method == 'general':
            attended = self.attn(encoder_outputs)
            attns = T.baddbmm(lengths, attended,
                              hidden.transpose(2, 1)).squeeze(2)

        elif self.method == 'concat':
            concated = T.cat(
                (hidden.expand_as(encoder_outputs), encoder_outputs), 2)
            energy = self.attn(concated)
            # NOTE(review): the `expand(` call below is truncated in this
            # excerpt (its remaining arguments and closing parenthesis are
            # missing), so this branch does not parse as shown.
            expanded = self.other.unsqueeze(0).expand(batch_size, 1,
            attns = T.baddbmm(lengths, energy,
                              expanded.transpose(2, 1)).squeeze(2)

        # NOTE(review): softmax is called without an explicit `dim`.
        return F.softmax(attns).unsqueeze(1)
Example #3
def var_lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor], w_ih: Tensor, w_hh: Tensor,
                  b_ih: Tensor = None, b_hh: Tensor = None, noise_in: Tensor = None, noise_hidden: Tensor = None) \
        -> Tuple[Tensor, Tensor]:
    """Run one LSTM step with per-gate weights and optional noise masks.

    The four gate projections are evaluated with one batched matrix product
    per operand: ``input`` and the hidden state are given a leading axis of
    size 4 (one slice per gate) and multiplied with ``w_ih`` / ``w_hh``.

    Args:
        input: Step input; gains a leading gate axis of size 4 internally.
        hidden: ``(hx, cx)`` previous hidden and cell state.
        w_ih: Input-to-hidden weights with the gate axis first.
        w_hh: Hidden-to-hidden weights with the gate axis first.
        b_ih: Optional input bias (gate axis first). ``None`` means no bias.
        b_hh: Optional hidden bias (gate axis first). ``None`` means no bias.
        noise_in: Optional mask multiplied into the unsqueezed input.
        noise_hidden: Optional mask multiplied into the unsqueezed hidden state.

    Returns:
        ``(hy, cy)``: the next hidden and cell state.
    """

    def _project(x, weight, bias):
        # The signature advertises None biases, so fall back to a plain
        # batched product instead of dereferencing None.
        if bias is None:
            return torch.bmm(x, weight)
        return torch.baddbmm(bias.unsqueeze(1), x, weight)

    hx, cx = hidden

    # Without a mask every gate sees the same tensor; with a mask the
    # broadcasted product yields the per-gate copies.
    if noise_in is None:
        input = input.expand(4, *input.size())
    else:
        input = input.unsqueeze(0) * noise_in
    if noise_hidden is None:
        hx = hx.expand(4, *hx.size())
    else:
        hx = hx.unsqueeze(0) * noise_hidden

    gates = _project(input, w_ih, b_ih) + _project(hx, w_hh, b_hh)

    # Unpack along the leading gate axis.
    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = forgetgate * cx + ingate * cellgate
    hy = outgate * torch.tanh(cy)

    return hy, cy
Example #4
    def update_stack(self, V, s, vt, ut, dt, t):
        # Writes `vt` into slot `t` of the value matrix, updates the strength
        # vector, and computes a read vector; returns (Vp, sp, rt).
        # batch x k x stack dim

        # Add `vt` at index `t` without modifying `V` in place.
        # NOTE(review): torch.zeros(...) here and below uses the default
        # dtype/device rather than those of the inputs.
        zeromat = torch.zeros(V.shape)
        zeromat[:, t, :] = vt
        Vp = V + zeromat

        # NOTE(review): this baddbmm call is truncated in the excerpt -- only
        # its first argument is present and the closing parenthesis is
        # missing, so the function does not parse as shown.
        prod = torch.baddbmm(
            self.matrix_partial_sums[:, 0:s.shape[1], 0:s.shape[1]],
        zero_vals = torch.zeros(s.shape)

        # NOTE(review): squeeze() without a dim also drops a batch axis of size 1.
        prod = prod.squeeze()
        new_strength = ut.repeat(1, s.shape[1]) - prod

        # Clamp at zero twice: once for the amount removed, once for the result.
        sp = torch.max(zero_vals, s - torch.max(zero_vals, new_strength))
        #sp = torch.cat((sp, dt), dim = 1)
        sp[:, t] = dt[:, 0]
        # This initial value is overwritten by the baddbmm result below.
        rt = torch.zeros(self.batch_size, self.stack_dim)
        inner_max_vec = self.relu(torch.ones(prod.shape) - prod - sp)
        matrix_A = torch.min(sp, inner_max_vec).unsqueeze(1)

        # Zero additive term, so this is matrix_A @ Vp per batch element.
        rt = torch.baddbmm(torch.zeros(self.batch_size, 1, self.stack_dim),
                           matrix_A, Vp).squeeze()

        return Vp, sp, rt
    def test_baddbmm(self):
        """Compare auto-mixed-precision ``baddbmm`` with a manual bfloat16 run.

        For every ``alpha``/``beta`` pair drawn from {0.8, 1.0}, the same
        random operands are evaluated twice: once after an explicit cast to
        bfloat16 with auto-mixed precision disabled, and once as float32
        tensors with auto-mixed precision enabled. The auto-mixed result must
        report ``torch.float`` and equal the manual result cast back to float.
        """
        rand_seed = int(get_rand_seed())
        print("{} rand sed: {}".format(sys._getframe().f_code.co_name, rand_seed))
        # Apply the printed seed; without this the logged value cannot be used
        # to replay a failing run.
        torch.manual_seed(rand_seed)
        for i in range(8, 12, 2):
            for j in range(8, 12, 2):
                alpha = i / 10
                beta = j / 10
                batches = 2

                M, N, O = 23, 8, 12
                x_auto_mix_a = torch.randn(batches, M, N, dtype=torch.float32, device=device)
                x_auto_mix_b = torch.randn(batches, N, O, dtype=torch.float32, device=device)
                add_auto_mix = torch.randn(batches, M, O, dtype=torch.float32, device=device)

                x_man_bf16_a = x_auto_mix_a.to(torch.bfloat16)
                x_man_bf16_b = x_auto_mix_b.to(torch.bfloat16)
                add_man_bf16 = add_auto_mix.to(torch.bfloat16)

                with AutoDNNL(True), AutoMixPrecision(False):
                    res_man_bf16 = torch.baddbmm(add_man_bf16, x_man_bf16_a, x_man_bf16_b, beta=beta, alpha=alpha)
                    with AutoMixPrecision(True):
                        res_auto_mix = torch.baddbmm(add_auto_mix, x_auto_mix_a, x_auto_mix_b, beta=beta, alpha=alpha)
                        self.assertEqual(res_auto_mix.dtype, torch.float)
                        self.assertEqual(res_auto_mix, res_man_bf16.float())
Example #6
        def calculate_convolution(weight_gate_list, bias_gate_list,
                                  weight_list, bias_list, index, d_kind,
                                  adj_matrix, h):
            # Nested helper: `self` and `current_batch_size` are not
            # parameters and are taken from the enclosing scope.
            # Computes a (optionally gated) relation term for entry `index`
            # and multiplies it with `adj_matrix[:, d_kind]`.
            if self.gate_flag:
                # NOTE(review): the baddbmm call below is garbled in this
                # excerpt -- arguments are missing and the parentheses do not
                # balance, so it does not parse as shown.
                z = torch.baddbmm(bias_gate_list[index].expand(
                    [current_batch_size, -1, -1]),
                                      [current_batch_size, -1, -1]),
                gate = self.gate_activity(z)
                # NOTE(review): as written this overwrites the gate computed
                # on the previous line; an `else:` was probably lost from the
                # excerpt -- confirm against the original.
                gate = 1

            # NOTE(review): this call is truncated too (no `batch2` argument,
            # no closing parenthesis).
            relation = torch.baddbmm(
                bias_list[index].expand([current_batch_size, -1, -1]),
                batch1=weight_list[index].expand([current_batch_size, -1, -1]),

            # Dropout is applied only when the module defines one.
            if hasattr(self, 'dropout'):
                relation = self.dropout(relation)

            relation = self.norm_item * gate * relation

            relation = torch.bmm(relation, adj_matrix[:, d_kind])
            return relation
def SkipConnectLSTMCell(input,
    # NOTE(review): the parameter list is cut off after `input,` in this excerpt.
    # The body also reads `hidden`, `hidden_skip`, `w_ih`, `w_hh`, `b_ih`,
    # `b_hh`, `noise_in` and `noise_hidden`; presumably they belong to the
    # missing part of the signature -- restore it from the original.
    #
    # One LSTM step in which the previous hidden state is concatenated with
    # `hidden_skip` along dim 1 before the hidden-to-hidden projection.
    input = input.expand(
        4, *
        input.size()) if noise_in is None else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = torch.cat([hx, hidden_skip], dim=1)
    hx = hx.expand(
        4, *
        hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # bias + (tensor @ weight) for all four gates at once.
    # NOTE(review): the biases are dereferenced unconditionally.
    gates = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih) + torch.baddbmm(
        b_hh.unsqueeze(1), hx, w_hh)

    # Unpack along the leading (gate) axis.
    ingate, forgetgate, cellgate, outgate = gates

    # NOTE(review): F.sigmoid / F.tanh are deprecated aliases in current
    # PyTorch; torch.sigmoid / torch.tanh are the replacements.
    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)

    return hy, cy
Example #8
 def reordered(self, q, k, v, mask, head_m, **kw):
     """Attention whose scores are computed in float32 via ``torch.baddbmm``.

     The scaling factor is folded into the ``alpha`` argument of ``baddbmm``.
     When ``is_amp`` is true the product runs with autocast disabled;
     otherwise it runs directly. Exactly one of the two paths executes.

     Args:
         q, k, v: 4-D query/key/value tensors (unpacked as ``b, n, len, d``).
         mask: Optional additive mask applied to the scores.
         head_m: Optional multiplicative mask applied to the probabilities.
         **kw: Forwarded to ``self.get_y_opts``.

     Returns:
         A tuple holding the attention output and, when ``yo.attn`` is set,
         the attention probabilities as well.
     """
     cfg = self.cfg
     yo = self.get_y_opts(**kw)
     alpha = 1.0
     if cfg.scale:
         alpha /= float(v.size(-1))**0.5
     if cfg.scale_by_inv:
         alpha /= float(self.lay_i + 1)
     b, n, n_q, d = q.size()
     _, _, n_k, _ = k.size()
     # Pre-allocated input for baddbmm; with beta=0 its uninitialised
     # contents are ignored.
     a = torch.empty(b * n, n_q, n_k, dtype=torch.float32, device=q.device)
     if is_amp:
         with autocast(enabled=False):
             q, k = q.reshape(-1, n_q,
                              d), k.transpose(-1, -2).reshape(-1, d, n_k)
             a = torch.baddbmm(a, q.float(), k.float(), beta=0, alpha=alpha)
             a = a.reshape(b, n, n_q, n_k)
     else:
         # Without this branch the non-AMP case would return the raw
         # torch.empty buffer, and the AMP case would run the product a
         # second time on tensors that were already reshaped.
         q, k = q.reshape(-1, n_q, d), k.transpose(-1,
                                                   -2).reshape(-1, d, n_k)
         a = torch.baddbmm(a, q.float(), k.float(), beta=0, alpha=alpha)
         a = a.reshape(b, n, n_q, n_k)
     if not self.is_cross:
         # Keep scores where the stored bias mask is set; use bias_m elsewhere.
         m = self.bias[:, :, n_k - n_q:n_k, :n_k].bool()
         a = torch.where(m, a, self.bias_m.to(a.dtype))
     if mask is not None:
         a = a + mask
     # Softmax in the score dtype, then cast to v's dtype before dropout.
     a = self.drop_attn(F.softmax(a, dim=-1).type(v.dtype))
     if head_m is not None:
         a = a * head_m
     y = (torch.matmul(a, v), )
     if yo.attn:
         y += (a, )
     return y
def bpdist(feature, data_format='NCW'):
    """Compute pairwise (square) distances of features.

    Based on $(x-y)^2=x^2+y^2-2xy$.

    Args:
        feature (torch.Tensor): (batch_size, channels, num_inst) for 'NCW',
            (batch_size, num_inst, channels) for 'NWC'.
        data_format (str): the format of features. [NCW/NWC]

    Returns:
        distance (torch.Tensor): (batch_size, num_inst, num_inst)

    Notes:
        This method returns square distances, and is optimized for lower memory and faster speed.
        Square sum is more efficient than gather diagonal from inner product.
        The result is somehow inaccurate compared to directly using $(x-y)^2$.

    """
    assert data_format in ('NCW', 'NWC')
    if data_format == 'NCW':
        # x^2 + y^2 as a (B, N, N) grid, built by broadcasting.
        square_sum = torch.sum(feature**2, 1, keepdim=True)
        square_sum = square_sum.transpose(1, 2) + square_sum
        # ... minus 2xy, fused into a single baddbmm call via alpha.
        distance = torch.baddbmm(square_sum,
                                 feature.transpose(1, 2),
                                 feature,
                                 alpha=-2.0)
    else:
        square_sum = torch.sum(feature**2, 2, keepdim=True)
        square_sum = square_sum.transpose(1, 2) + square_sum
        distance = torch.baddbmm(square_sum,
                                 feature,
                                 feature.transpose(1, 2),
                                 alpha=-2.0)
    return distance
Example #10
def SkipConnectGRUCell(input,
    # NOTE(review): the parameter list is cut off after `input,` in this excerpt.
    # The body also reads `hidden`, `hidden_skip`, `w_ih`, `w_hh`, `b_ih`,
    # `b_hh`, `noise_in` and `noise_hidden`; presumably they belong to the
    # missing part of the signature -- restore it from the original.
    #
    # One GRU step in which the previous hidden state is concatenated with
    # `hidden_skip` along dim 1 before the hidden-to-hidden projection.
    input = input.expand(
        3, *
        input.size()) if noise_in is None else input.unsqueeze(0) * noise_in
    hx = torch.cat([hidden, hidden_skip], dim=1)
    hx = hx.expand(
        3, *
        hx.size()) if noise_hidden is None else hx.unsqueeze(0) * noise_hidden

    # bias + (tensor @ weight) for all three gates at once.
    # NOTE(review): the biases are dereferenced unconditionally.
    gi = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    # Unpack along the leading (gate) axis.
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    # NOTE(review): F.sigmoid / F.tanh are deprecated aliases in current
    # PyTorch; torch.sigmoid / torch.tanh are the replacements.
    resetgate = F.sigmoid(i_r + h_r)
    inputgate = F.sigmoid(i_i + h_i)
    newgate = F.tanh(i_n + resetgate * h_n)
    # Interpolates with the un-concatenated `hidden`, not with `hx`.
    hy = newgate + inputgate * (hidden - newgate)

    return hy
def bpdist2(feature1, feature2, data_format='NCW'):
    """Compute pairwise (square) distances of features.

    Uses $(x-y)^2=x^2+y^2-2xy$, with the cross term fused into one
    ``torch.baddbmm`` call.

    Args:
        feature1 (torch.Tensor): (batch_size, channels, num_inst1)
        feature2 (torch.Tensor): (batch_size, channels, num_inst2)
        data_format (str): the format of features. [NCW/NWC]
            For 'NWC' the last two axes of both inputs are swapped.

    Returns:
        distance (torch.Tensor): (batch_size, num_inst1, num_inst2)

    """
    assert data_format in ('NCW', 'NWC')
    if data_format == 'NCW':
        square_sum1 = torch.sum(feature1**2, 1, keepdim=True)
        square_sum2 = torch.sum(feature2**2, 1, keepdim=True)
        # (B, N1, 1) + (B, 1, N2) broadcasts to the full (B, N1, N2) grid.
        square_sum = square_sum1.transpose(1, 2) + square_sum2
        distance = torch.baddbmm(square_sum,
                                 feature1.transpose(1, 2),
                                 feature2,
                                 alpha=-2.0)
    else:
        square_sum1 = torch.sum(feature1**2, 2, keepdim=True)
        square_sum2 = torch.sum(feature2**2, 2, keepdim=True)
        square_sum = square_sum1 + square_sum2.transpose(1, 2)
        distance = torch.baddbmm(square_sum,
                                 feature1,
                                 feature2.transpose(1, 2),
                                 alpha=-2.0)
    return distance
    def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
        """Attention with float32 score computation and fused scaling.

        The query/key product is evaluated with ``torch.baddbmm`` so that the
        scale factor is applied through ``alpha``. When ``is_amp_available``
        is true the product runs with autocast disabled; otherwise it runs
        directly. Exactly one of the two paths executes.

        Args:
            query, key, value: 4-D tensors (unpacked as batch, heads, length, dim).
            attention_mask: Optional additive mask for the scores.
            head_mask: Optional multiplicative mask for the probabilities.

        Returns:
            ``(attn_output, attn_weights)``.

        Raises:
            RuntimeError: If the softmax output is not float32.
        """
        # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
        bsz, num_heads, q_seq_len, dk = query.size()
        _, _, k_seq_len, _ = key.size()

        # Preallocate attn_weights for `baddbmm` (ignored as input because beta=0)
        attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=torch.float32, device=query.device)

        # Compute Scale Factor
        scale_factor = 1.0
        if self.scale_attn_weights:
            scale_factor /= float(value.size(-1)) ** 0.5

        if self.scale_attn_by_inverse_layer_idx:
            scale_factor /= float(self.layer_idx + 1)

        # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
        if is_amp_available:
            with autocast(enabled=False):
                q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
                attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
                attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
        else:
            # Without this branch the product would be skipped entirely when
            # AMP is unavailable (returning the uninitialised buffer) and
            # executed a second time, on a 4-D buffer, when it is available.
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

        if not self.is_cross_attention:
            # if only "normal" attention layer implements causal mask
            query_length, key_length = query.size(-2), key.size(-2)
            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool()
            mask_value = torch.finfo(attn_weights.dtype).min
            # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
            # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
            mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
            attn_weights = torch.where(causal_mask, attn_weights, mask_value)

        if attention_mask is not None:
            # Apply the attention mask
            attn_weights = attn_weights + attention_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1)

        # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op if otherwise
        if attn_weights.dtype != torch.float32:
            raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
        attn_weights = attn_weights.type(value.dtype)
        attn_weights = self.attn_dropout(attn_weights)

        # Mask heads if we want to
        if head_mask is not None:
            attn_weights = attn_weights * head_mask

        attn_output = torch.matmul(attn_weights, value)

        return attn_output, attn_weights
Example #13
 def forward(self, z):
     # Transforms `z`, stores `self.logdet_jacobian(z)` on `self.current_det`
     # as a side effect, and returns the squeezed result.
     # First step: bias + w^T z for every row of the batch.
     res = torch.baddbmm(self.bias.repeat(z.shape[0], 1).unsqueeze(2),
                        self.w.unsqueeze(0).repeat(z.shape[0], 1, 1).transpose(1,2), 
                        z.unsqueeze(2)) #+ b.repeat(z.shape[0]).
     # NOTE(review): the call below is truncated in this excerpt -- only the
     # additive term is present and the parenthesis is never closed, so the
     # method does not parse as shown.
     res = torch.baddbmm(z.unsqueeze(2),
     # NOTE(review): squeeze() without a dim also drops a batch axis of size 1.
     res = res.squeeze()
     self.current_det = self.logdet_jacobian(z)
     return res
Example #14
 def logdet_jacobian(self, z):
     """Return ``|1 + psi(z) @ u|`` for every row of ``z``.

     ``psi(z) = (1 - tanh(w^T z + bias)^2) * w^T`` is evaluated with batched
     matrix products; the value returned is the absolute value itself (no
     logarithm is taken here).

     Args:
         z: 2-D tensor with the batch on axis 0.

     Returns:
         Tensor whose leading axis is the batch axis; the singleton axes
         produced by the batched products are removed. The batch axis is
         kept even when the batch size is 1.
     """
     n = z.shape[0]
     # Broadcast views of the parameters instead of materialised copies.
     w_t = self.w.unsqueeze(0).expand(n, -1, -1).transpose(1, 2)
     u = self.u.unsqueeze(0).expand(n, -1, -1)

     pre = torch.baddbmm(self.bias.repeat(n, 1).unsqueeze(2), w_t,
                         z.unsqueeze(2))
     psi = torch.bmm(1 - torch.pow(torch.tanh(pre), 2), w_t)
     # new_ones keeps the additive term on the same dtype/device as the
     # operands; a bare torch.ones would break for float64 or GPU tensors.
     det = torch.baddbmm(psi.new_ones(n, 1, 1), psi, u)
     # Squeeze explicit axes only, so a batch of one is not collapsed to 0-d.
     return det.squeeze(1).squeeze(-1).abs()
Example #15
def orthogonal(points, calibrations, transforms=None):
    """
    Compute the orthogonal projections of 3D points into the image plane by given projection matrix
    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 4, 4] Tensor of projection matrix
    :param transforms: [B, 2, 3] Tensor of image transform matrix
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    """
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is not None:
        # Slice rows/columns of every batch element; indexing without the
        # leading `:` would cut the batch axis instead.
        scale = transforms[:, :2, :2]
        shift = transforms[:, :2, 2:3]
        xy = torch.baddbmm(shift, scale, pts[:, :2, :])
        # Rebuild the tensor rather than writing into `pts`, which is itself
        # an operand of the product above.
        pts = torch.cat([xy, pts[:, 2:, :]], dim=1)
    return pts
Example #16
def gram_matrix(input_tensor, device=torch.device("cuda")):
    """
    Compute Gram matrix

    :param input_tensor: input tensor with shape
     (batch_size, nbr_channels, height, width)
    :param device: kept for backward compatibility. The scratch buffer is
     allocated next to ``input_tensor`` because ``baddbmm`` needs all of its
     operands on one device.
    :return: Gram matrix of y, shape (batch_size, nbr_channels, nbr_channels),
     normalised by nbr_channels * height * width
    """
    (b, ch, h, w) = input_tensor.size()
    features = input_tensor.reshape(b, ch, w * h)
    features_t = features.transpose(1, 2)

    # more efficient and formal way to avoid underflow for mixed precision training:
    # the normalisation is applied inside the product through `alpha`, and
    # beta=0 means the scratch buffer's contents are ignored.
    scratch = features.new_zeros(b, ch, ch)
    gram = torch.baddbmm(scratch,
                         features,
                         features_t,
                         beta=0,
                         alpha=1. / (ch * h * w))

    # naive way to avoid underflow for mixed precision training
    # features = features / (ch * h)
    # gram = features.bmm(features_t) / w

    # for fp32 training, it is also safe to use the following:
    # gram = features.bmm(features_t) / (ch * h * w)

    return gram
Example #17
    def backward(self, grad_output):
        """Return the gradient with respect to ``self.input1``.

        ``grad_output`` is viewed as (batch, height*width*depth, 3) and
        ``self.batchgrid`` as (batch, height*width*depth, 4); the result is
        the batched product of the transposed gradient with the grid, shaped
        like ``self.input1``.
        """
        n_points = self.height * self.width * self.depth
        grad_out_t = torch.transpose(grad_output.view(-1, n_points, 3), 1, 2)
        grid = self.batchgrid.view(-1, n_points, 4)

        # new_zeros follows input1's dtype and device; a bare torch.zeros
        # would fail for GPU or non-default-dtype inputs.
        grad_input1 = self.input1.new_zeros(self.input1.size())
        grad_input1 = torch.baddbmm(grad_input1, grad_out_t, grid)

        return grad_input1
Example #18
def batched_all_pairs_squared_l2_dist(
    a: FloatTensorType,
    b: FloatTensorType,
) -> FloatTensorType:
    """For each batch, return the squared L2 distance between each pair of vectors

    Let A and B be tensors of shape NxM_AxD and NxM_BxD, each containing N*M_A
    and N*M_B vectors of dimension D grouped in N batches of size M_A and M_B.
    For each batch, for each vector of A and each vector of B, return the sum
    of the squares of the differences of their components.

    The result has shape NxM_AxM_B.
    """
    num_chunks, num_a, dim = match_shape(a, -1, -1, -1)
    num_b = match_shape(b, num_chunks, -1, dim)
    a_squared = a.norm(dim=-1).pow(2)
    b_squared = b.norm(dim=-1).pow(2)
    # Calculate res_i,k = sum_j((a_i,j - b_k,j)^2) for each i and k as
    # sum_j(a_i,j^2) - 2 sum_j(a_i,j b_k,j) + sum_j(b_k,j^2), by using a matrix
    # multiplication for the ab part, adding the b^2 as part of the baddbmm call
    # and the a^2 afterwards.
    res = torch.baddbmm(
        b_squared.unsqueeze(-2), a, b.transpose(-2, -1), alpha=-2
    ) + a_squared.unsqueeze(-1)
    # Sanity-check the output shape with the same helper used for the inputs.
    match_shape(res, num_chunks, num_a, num_b)
    return res
Example #19
    def forward(self, left_in, right_in):
        # Composes a left and a right child state into a parent state.
        # `bundle` / `unbundle` are helpers defined outside this excerpt.
        left, right = bundle(left_in), bundle(right_in)
        lstm_gates = self.left(left.h)
        lstm_gates += self.right(right.h)

        # Core composition
        # Retrieve hidden state from left_in and feed it into weight matrix
        # NOTE(review): this empty list is dead -- it is overwritten by the
        # matmul result before being read.
        cell_inp = []
        hidden_dim = self.dim * self.dim

        # Each hidden vector is viewed as a (dim x dim) matrix.
        h = left.h.contiguous().view(-1, self.dim, self.dim)
        cell_inp = torch.matmul(self.weight, h)
        cell_inp = F.tanh(torch.add(cell_inp, self.b1))

        # Retrieve hidden state from right_in
        h = right.h.contiguous().view(-1, self.dim, self.dim)
        # b2 + cell_inp @ h per batch element, followed by tanh.
        cell_inp = F.tanh(torch.baddbmm(self.b2, cell_inp, h))
        # Flatten back to one vector of dim*dim per row.
        cell_inp = cell_inp.view(-1, hidden_dim)

        # NOTE(review): the `unbundle(` call is truncated in this excerpt
        # (arguments and closing parenthesis are missing), so the method does
        # not parse as shown.
        out = unbundle(

        return out
Example #20
 def set_full_solution_batched(self):
     # Combines essential boundary conditions and solution of equation system
     # Stores the result on `self.full_solution_torch`.
     # NOTE(review): the baddbmm argument list is garbled in this excerpt --
     # the two lines below are the tails of calls whose beginnings are
     # missing, so the parentheses do not balance and the method does not
     # parse as shown. Only the last operand, `self.solution_torch`, is intact.
     self.full_solution_torch = torch.baddbmm(
             self.solution_torch.shape[0], -1, -1),
             self.solution_torch.shape[0], -1, -1), self.solution_torch)
Example #21
    def prediction_layer(self, S_i, u_i, instg_S_i, q_j, qb_j, for_pred=False):
        # Scores items as the sum of three terms labelled below: long-term,
        # short-term and item-item.
        # NOTE(review): `self` is not read anywhere in the visible body.

        if for_pred:
            # Non-batched path: plain matrix products against q_j^T.
            q_j = q_j.squeeze()
            qb_j = qb_j.squeeze()

            # long-term
            res = u_i.mm(q_j.t()) + qb_j

            # short-term
            res += instg_S_i.mm(q_j.t())

            # item-item
            rel_score = torch.matmul(S_i, q_j.t().unsqueeze(0))
            rel_score = torch.sum(rel_score, dim=1)
            res += rel_score
            # NOTE(review): as written, the batched computation below runs
            # inside the `for_pred` branch and overwrites `res`, while a call
            # with for_pred=False reaches `return res` with `res` unbound. An
            # `else:` was probably lost from this excerpt -- confirm against
            # the original.
            # long-term
            res = torch.baddbmm(qb_j, q_j, u_i.unsqueeze(2)).squeeze()

            # short-term
            # NOTE(review): the statement below is truncated (the `permute`
            # and `bmm` calls are never closed), so the method does not parse
            # as shown.
            res += torch.bmm(instg_S_i.unsqueeze(1), q_j.permute(0, 2,

            # item-item
            rel_score = S_i.bmm(q_j.permute(0, 2, 1))
            rel_score = torch.sum(rel_score, dim=1)
            res += rel_score

        return res
Example #22
def batch_linear(x, W, b=None):
    """Computes y_i = W_i x_i + b where i is each observation index.

    This is similar to `torch.nn.functional.linear`, but a version that
    supports a different W for each observation.

    Args:
        x: has shape [obs, in_dims]
        W: has shape [obs, out_dims, in_dims]
        b: has shape [out_dims]; shared by all observations. Optional.

    Returns:
        Tensor of shape [obs, out_dims].

    Raises:
        ValueError: if the in_dims or obs sizes of ``x`` and ``W`` disagree.
    """
    if x.size()[1] != W.size()[-1]:
        raise ValueError(
            f'the in_dim of x ({x.size()[1]}) does not match in_dim of W ({W.size()[-1]})')

    if x.size()[0] != W.size()[0]:
        raise ValueError(
            f'the obs of x ({x.size()[0]}) does not match obs of W ({W.size()[0]})')

    obs = x.size()[0]
    in_dims = x.size()[1]
    out_dims = W.size()[1]

    # Treat each observation as a 1 x in_dims row multiplied by W_i^T.
    x = x.view(obs, 1, in_dims)
    W = W.transpose(-2, -1)

    if b is None:
        return torch.bmm(x, W).view(obs, out_dims)

    # The bias broadcasts over the batch; beta and alpha are left at their
    # default of 1 instead of using the deprecated positional-scalar overload.
    b = b.view(1, 1, out_dims)
    return torch.baddbmm(b, x, W).view(obs, out_dims)
Example #23
    def get_score(self, target_idx, noise_idx, input):
        """Score every position against its target word and its noise words.

        For each position the rows of ``self.weight`` / ``self.bias`` selected
        by the target index and the noise indices are gathered, and the score
        is ``input . weight_row + bias_entry``.

        Shape:
            - target_idx: 2-D index tensor (the result is sliced with
              ``out[:, :, 0]``, which requires exactly two leading axes)
            - noise_idx: ``target_idx`` shape plus a trailing axis of N_r indices
            - input: ``target_idx`` shape plus a trailing embedding axis
            - target_batch :math:`(N, E, 1+N_r)`where `N = length, E = embedding size, N_r = noise ratio`

        Returns:
            ``(target_score, noise_score)``: scores shaped like ``target_idx``
            and like ``noise_idx`` respectively.
        """
        # flatten the following matrix
        input = input.contiguous().view(-1, input.size(-1))
        # the size will be used to pack the output of indexlinear
        original_size = target_idx.size()
        target_idx = target_idx.view(-1)
        noise_idx = noise_idx.view(-1, noise_idx.size(-1))

        # Column 0 holds the target index, the remaining columns the noise.
        indices = torch.cat([target_idx.unsqueeze(-1), noise_idx], dim=-1)

        # the pytorch's [] operator can't BP correctly with redundant indices
        # before version 0.2.0
        input = input.unsqueeze(1)
        flat_indices = indices.view(-1)
        target_batch = self.weight.index_select(0, flat_indices).view(
            *indices.size(), -1).transpose(1, 2)
        bias = self.bias.index_select(
            0, flat_indices).view_as(indices).unsqueeze(1)
        # beta and alpha default to 1; the positional-scalar overload of
        # baddbmm is deprecated.
        out = torch.baddbmm(bias, input,
                            target_batch).view(*original_size, -1)
        target_score, noise_score = out[:, :, 0], out[:, :, 1:]
        return target_score, noise_score
Example #24
def gram_matrix(input_tensor):
    """
    Compute Gram matrix
    :param input_tensor: input tensor with shape
    (batch_size, nbr_channels, height, width)
    :return: Gram matrix of y, shape (batch_size, nbr_channels, nbr_channels),
    normalised by nbr_channels * height * width

    Ripped from: https://github.com/NVIDIA/partialconv/blob/master/models/loss.py#L17-L40
    """
    (b, ch, h, w) = input_tensor.size()
    features = input_tensor.reshape(b, ch, w * h)
    features_t = features.transpose(1, 2)

    # more efficient and formal way to avoid underflow for mixed precision training:
    # the normalisation is applied inside the product through `alpha`, and
    # beta=0 means the scratch buffer's contents are ignored.
    scratch = features.new_zeros(b, ch, ch)
    gram = torch.baddbmm(scratch,
                         features,
                         features_t,
                         beta=0,
                         alpha=1. / (ch * h * w))

    # naive way to avoid underflow for mixed precision training
    # features = features / (ch * h)
    # gram = features.bmm(features_t) / w

    # for fp32 training, it is also safe to use the following:
    # gram = features.bmm(features_t) / (ch * h * w)

    return gram
Example #25
def reproject_points(points_cam_ref, extrinsics_ref, extrinsics_tgt):
    """Reproject points in reference camera coordinate to target camera coordinate

    Args:
        points_cam_ref (B, 3, H, W): points in reference camera coordinate.
        extrinsics_ref (B, 3, 4): [R, t] of reference camera.
        extrinsics_tgt (B, 3, 4): [R, t] of target_camera.

    Returns:
        points_cam_tgt (B, 3, H, W): points in target camera coordinate.

    """
    B, p_dim, H, W = points_cam_ref.shape
    assert p_dim == 3, "dimension of point {} != 3".format(p_dim)

    # t + R * p where t of (B, 3, 1), R of (B, 3, 3) and p of (B, 3, H*W)
    R_ref = extrinsics_ref[..., :p_dim]
    t_ref = extrinsics_ref[..., -1:]
    # reshape (not view) so non-contiguous point maps are accepted as well.
    points_world = torch.baddbmm(t_ref, R_ref,
                                 points_cam_ref.reshape(B, p_dim, -1))

    # Reproject to target:
    # R'^T * (p - t') where t' of (B, 3, 1), R' of (B, 3, 3) and p of (B, 3, H*W)
    R_tgt = extrinsics_tgt[..., :p_dim]
    t_tgt = extrinsics_tgt[..., -1:]
    points_cam_tgt = torch.bmm(R_tgt.transpose(1, 2), points_world - t_tgt)
    return points_cam_tgt.view(B, p_dim, H, W)
Example #26
    def apply_classification_weights(self, x_transformed, weights):
        """
        Computing logits for classification

        Both operands are L2-normalised along their last axis, so each logit
        is ``scale_cls * (bias + cosine_similarity)``.

        Args:
            x_transformed : feature of query_x [batch_size, n_way, k_shot, feature_dim]
            weights : prototype_weight [batch_size, n_way, dim]

        Returns :
            logits : [batch_size, n_way, k_shot, n_way (num_classes)]
        """
        # Reshape
        task_size, n_way, k_shot, feature_dim = x_transformed.size()
        x_transformed = x_transformed.view(task_size, -1, feature_dim)

        # Normalizing
        x_transformed = F.normalize(x_transformed,
                                    p=2,
                                    dim=x_transformed.dim() - 1,
                                    eps=1e-12)
        weights = F.normalize(weights, p=2, dim=weights.dim() - 1, eps=1e-12)

        # logits [task_size, n_way * k_shot, n_way]
        # The scalar bias broadcasts over the whole product; beta and alpha
        # default to 1 (the positional-scalar overload is deprecated).
        logits = self.scale_cls * torch.baddbmm(
            self.bias.view(1, 1, 1), x_transformed, weights.transpose(1, 2))

        # Reshape : logits
        logits = logits.view(
            task_size, n_way, k_shot,
            -1)  # [batch_size, n_way, k_shot, n_way (num_classes)]

        return logits
    def forward(self, input: torch.Tensor):
        """Whiten ``input`` with an iteratively computed inverse square root.

        The input is viewed as (batch, num_groups, rest). Per batch element
        the centred rows give a (num_groups x num_groups) covariance, which
        is trace-normalised; ``self.T`` steps of
        ``P <- 1.5 * P - 0.5 * P^3 @ sigma_N`` (starting from the identity)
        approximate its inverse square root, which is then applied to the
        centred data.

        Args:
            input: tensor with ``self.dim`` axes and ``self.num_features``
                entries on axis 1.

        Returns:
            Tensor shaped like ``input``; scaled and shifted by
            ``self.weight`` / ``self.bias`` when ``self.affine`` is set.
        """
        size = input.size()
        assert input.dim() == self.dim and size[1] == self.num_features
        x = input.view(size[0], self.num_groups, size[1] // self.num_groups, *size[2:])

        x = x.view(size[0], self.num_groups, -1)
        IG, d, m = x.size()
        x_mean = x - x.mean(-1, keepdim=True)
        sigma = x_mean.matmul(x_mean.transpose(1, 2)) / m

        identity = torch.eye(d).to(x).expand(sigma.shape)
        # 1 / trace(sigma), kept as (batch, 1, 1) for broadcasting.
        trace_inv = (sigma * identity).sum((1, 2), keepdim=True).reciprocal_()
        sigma_N = sigma * trace_inv

        P = identity
        for _ in range(self.T):
            # Keyword beta/alpha replace the deprecated positional-scalar
            # overload baddbmm(beta, input, alpha, batch1, batch2).
            P = torch.baddbmm(P, self.matrix_power3(P), sigma_N, beta=1.5, alpha=-0.5)

        # Undo the trace normalisation.
        wm = P * trace_inv.sqrt()
        y = wm.matmul(x_mean)
        output = y.view(size[0], self.num_groups, size[1] // self.num_groups, *size[2:])
        output = output.view_as(input)
        if self.affine:
            output = output * self.weight + self.bias
        return output
Example #28
    def forward(self, mix_hidden, query):
        # Returns a sigmoid mask over the second and third axes of
        # `mix_hidden`, scored against `query` with either the 'dot' or the
        # 'align' mode. Any other mode falls through and returns None.
        # todo: get this right; it is actually also possible to drop the memory entirely and do the attention directly | DONE
        BATCH_SIZE = mix_hidden.size()[0]
        # assert query.size()==(BATCH_SIZE,self.hidden_size)
        # assert mix_hidden.size()[-1]==self.hidden_size
        #mix_hidden:bs,max_len,fre,hidden_size  query:bs,hidden_size
        if self.mode == 'dot':
            # mix_hidden=mix_hidden.view(-1,1,self.hidden_size)
            mix_shape = mix_hidden.size()
            mix_hidden = mix_hidden.view(BATCH_SIZE, -1, self.hidden_size)
            query = query.view(-1, self.hidden_size, 1)
            # print '\n\n',mix_hidden.requires_grad,query.requires_grad,'\n\n'
            # NOTE(review): the baddbmm call below is truncated in this
            # excerpt (second batch operand and closing parenthesis are
            # missing), so the method does not parse as shown. It also
            # allocates its additive term with .cuda() unconditionally.
            dot = torch.baddbmm(Variable(torch.zeros(1, 1).cuda()), mix_hidden,
            energy = dot.view(BATCH_SIZE, mix_shape[1], mix_shape[2])
            mask = F.sigmoid(energy)
            return mask

        elif self.mode == 'align':
            # mix_hidden=Variable(mix_hidden)
            # query=Variable(query)
            mix_shape = mix_hidden.size()
            # Project both operands to align_hidden_size, add with
            # broadcasting over the middle axis, then score with Linear_3.
            mix_hidden = mix_hidden.view(-1, self.hidden_size)
            mix_hidden = self.Linear_1(mix_hidden).view(
                BATCH_SIZE, -1, self.align_hidden_size)
            query = self.Linear_2(query).view(
                -1, 1, self.align_hidden_size)  #bs,1,hidden
            # NOTE(review): `sum` shadows the builtin of the same name.
            sum = F.tanh(mix_hidden + query)
            energy = self.Linear_3(sum.view(-1, self.align_hidden_size)).view(
                BATCH_SIZE, mix_shape[1], mix_shape[2])
            mask = F.sigmoid(energy)
            return mask
Example #29
    def forward(self, x, lengths, h_prev, target_head, compute_softmax=False):
        """Run the encoder/RNN and decode each sequence with its own output head.

        Args:
            x: index tensor of n_seq x n_batch dimensions.
            lengths: sequence lengths passed to ``pack_padded_sequence``.
            h_prev: previous recurrent state.
            target_head: per-sequence indices selecting rows of
                ``self.decoder_W`` / ``self.decoder_b``.
            compute_softmax: when true, return ``self.softmax`` of the
                decoder output instead of the raw output.

        Returns:
            ``(prob_out, hidden)`` with ``prob_out`` of size
            seq * batch_size * vocab.
        """
        b_sz = x.size(1)
        n_steps = x.size(0)
        x = Variable(x).cuda()
        emb = self.enc_drop(self.encoder(x))
        packed = pack_padded_sequence(emb, lengths)

        rnn_out, hidden = self._my_recurrent_layer(packed, h_prev)

        rnn_out_unp = pad_packed_sequence(rnn_out)
        rnn_out = self.dec_drop(rnn_out_unp[0])

        # implement the multi-headed RNN.
        W = self.decoder_W[target_head.cuda()]
        # reshape and expand b to size (batch*n_steps*vocab_size)
        b = self.decoder_b[target_head.cuda()].view(b_sz, -1, self.output_size)
        b = b.expand(b_sz, x.size(0), self.output_size)
        # output is size seq * batch_size * vocab
        dec_out = torch.baddbmm(b, rnn_out.transpose(0, 1), W).transpose(0, 1)

        if compute_softmax:
            prob_out = self.softmax(dec_out.view(-1, self.output_size)).view(
                n_steps, b_sz, self.output_size)
        else:
            # Without this branch the softmax result above was always
            # overwritten by the raw decoder output.
            prob_out = dec_out

        return prob_out, hidden
Example #30
    def forward_eval(self, x, h_prev, compute_softmax=True):
        """Evaluate one sequence against every output head at once.

        The recurrent output is expanded to ``self.num_output_layers`` copies
        so that all heads in ``self.decoder_W`` / ``self.decoder_b`` are
        applied in a single batched product.

        Args:
            x: index tensor of n_seq x n_batch dimensions; the batch is a
                single sequence.
            h_prev: previous recurrent state.
            compute_softmax: when true, return ``self.softmax`` of the
                decoder output instead of the raw output.

        Returns:
            ``(prob_out, hidden)`` with ``prob_out`` of size
            n_steps * n_auth * vocab.
        """
        n_auth = self.num_output_layers
        n_steps = x.size(0)
        x = Variable(x, volatile=True).cuda()
        # No Dropout needed
        emb = self.encoder(x)
        # No need for any packing here
        packed = emb

        rnn_out, hidden = self._my_recurrent_layer(packed, h_prev)

        # implement the multi-headed RNN.
        rnn_out = rnn_out.expand(n_steps, n_auth, self.hidden_size)
        W = self.decoder_W

        # reshape and expand b to size (n_auth*n_steps*vocab_size)
        b = self.decoder_b.view(n_auth, -1, self.output_size).expand(
            n_auth, n_steps, self.output_size)

        # output is size seq * batch_size * vocab
        dec_out = torch.baddbmm(b, rnn_out.transpose(0, 1), W).transpose(0, 1)

        if compute_softmax:
            prob_out = self.softmax(dec_out.contiguous().view(
                -1, self.output_size)).view(n_steps, n_auth, self.output_size)
        else:
            # Without this branch the softmax result above was always
            # overwritten by the raw decoder output.
            prob_out = dec_out

        return prob_out, hidden
Example #31
 def forward(ctx, add_batch, batch1, batch2, alpha=1, beta=1, inplace=False):
     """Compute ``alpha * add_batch + beta * (batch1 @ batch2)`` per batch.

     Note the naming: in this function ``alpha`` scales the added tensor and
     ``beta`` scales the product, which is the reverse of the keyword names
     used by ``torch.baddbmm``.

     The scalars, the size of ``add_batch`` and both batch operands are
     recorded on ``ctx``; the output buffer comes from ``_get_output``.
     """
     ctx.alpha = alpha
     ctx.beta = beta
     ctx.add_batch_size = add_batch.size()
     ctx.save_for_backward(batch1, batch2)
     output = _get_output(ctx, add_batch, inplace=inplace)
     # The deprecated positional overload is baddbmm(beta, input, alpha,
     # batch1, batch2): its first scalar multiplies the added tensor. The
     # keywords are therefore crossed on purpose to keep the same result.
     return torch.baddbmm(add_batch, batch1, batch2,
                          beta=alpha, alpha=beta, out=output)
Example #32
    def backward(self, grad_output):
        """Return the gradient with respect to ``self.input1``.

        ``grad_output`` is viewed as (batch, height*width, 2) and
        ``self.batchgrid`` as (batch, height*width, 3); the result is the
        batched product of the transposed gradient with the grid, accumulated
        into a zeroed buffer shaped like ``self.input1``.
        """
        n_points = self.height*self.width

        # Zeroed accumulator created from input1, so it shares its type.
        accumulator = self.input1.new(self.input1.size()).zero_()

        # if grad_output.is_cuda:
        #    self.batchgrid = self.batchgrid.cuda()
        #    grad_input1 = grad_input1.cuda()

        grad_out_t = torch.transpose(grad_output.view(-1, n_points, 2), 1, 2)
        grid = self.batchgrid.view(-1, n_points, 3)
        return torch.baddbmm(accumulator, grad_out_t, grid)