def VarGRUCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None,
               noise_in=None, noise_hidden=None):
    input = input.expand(3, *input.size()) if noise_in is None \
        else input.unsqueeze(0) * noise_in
    hx = hidden.expand(3, *hidden.size()) if noise_hidden is None \
        else hidden.unsqueeze(0) * noise_hidden

    gi = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    resetgate = torch.sigmoid(i_r + h_r)
    inputgate = torch.sigmoid(i_i + h_i)
    newgate = torch.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)
    return hy
def forward(self, hidden, encoder_outputs):
    lengths = None
    if type(encoder_outputs) is PackedSequence:
        encoder_outputs, lengths = pad(encoder_outputs, batch_first=True)
    else:
        lengths = [len(x) for x in encoder_outputs]
    batch_size = encoder_outputs.size()[0]
    attns = cuda(T.zeros(batch_size, max(lengths)), gpu_id=self.gpu_id)
    lengths = cuda(T.zeros(max(lengths), 1), gpu_id=self.gpu_id)

    if self.method == 'dot':
        attns = T.baddbmm(lengths, encoder_outputs,
                          hidden.transpose(2, 1)).squeeze(2)
    elif self.method == 'general':
        attended = self.attn(encoder_outputs)
        attns = T.baddbmm(lengths, attended,
                          hidden.transpose(2, 1)).squeeze(2)
    elif self.method == 'concat':
        concated = T.cat((hidden.expand_as(encoder_outputs), encoder_outputs), 2)
        energy = self.attn(concated)
        expanded = self.other.unsqueeze(0).expand(batch_size, 1, self.hidden_size)
        attns = T.baddbmm(lengths, energy,
                          expanded.transpose(2, 1)).squeeze(2)

    return F.softmax(attns).unsqueeze(1)
def var_lstm_cell(input: Tensor, hidden: Tuple[Tensor, Tensor],
                  w_ih: Tensor, w_hh: Tensor,
                  b_ih: Tensor = None, b_hh: Tensor = None,
                  noise_in: Tensor = None, noise_hidden: Tensor = None) \
        -> Tuple[Tensor, Tensor]:
    input = input.expand(4, *input.size()) if noise_in is None \
        else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = hx.expand(4, *hx.size()) if noise_hidden is None \
        else hx.unsqueeze(0) * noise_hidden

    gates = torch.add(torch.baddbmm(b_ih.unsqueeze(1), input, w_ih),
                      torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh))
    ingate, forgetgate, cellgate, outgate = gates

    ingate = torch.sigmoid(ingate)
    forgetgate = torch.sigmoid(forgetgate)
    cellgate = torch.tanh(cellgate)
    outgate = torch.sigmoid(outgate)

    cy = torch.add(torch.mul(forgetgate, cx), torch.mul(ingate, cellgate))
    hy = torch.mul(outgate, torch.tanh(cy))
    return hy, cy
def update_stack(self, V, s, vt, ut, dt, t):
    # batch x k x stack dim
    zeromat = torch.zeros(V.shape)
    zeromat[:, t, :] = vt
    Vp = V + zeromat

    prod = torch.baddbmm(torch.zeros(s.unsqueeze(-1).shape),
                         self.matrix_partial_sums[:, 0:s.shape[1], 0:s.shape[1]],
                         s.unsqueeze(-1))
    zero_vals = torch.zeros(s.shape)
    prod = prod.squeeze()

    new_strength = ut.repeat(1, s.shape[1]) - prod
    sp = torch.max(zero_vals, s - torch.max(zero_vals, new_strength))
    # sp = torch.cat((sp, dt), dim=1)
    sp[:, t] = dt[:, 0]

    rt = torch.zeros(self.batch_size, self.stack_dim)
    inner_max_vec = self.relu(torch.ones(prod.shape) - prod - sp)
    matrix_A = torch.min(sp, inner_max_vec).unsqueeze(1)
    rt = torch.baddbmm(torch.zeros(self.batch_size, 1, self.stack_dim),
                       matrix_A, Vp).squeeze()
    return Vp, sp, rt
def test_baddbmm(self):
    rand_seed = int(get_rand_seed())
    print("{} rand seed: {}".format(sys._getframe().f_code.co_name, rand_seed))
    torch.manual_seed(rand_seed)
    for i in range(8, 12, 2):
        for j in range(8, 12, 2):
            alpha = i / 10
            beta = j / 10
            batches = 2
            M, N, O = 23, 8, 12
            x_auto_mix_a = torch.randn(batches, M, N, dtype=torch.float32, device=device)
            x_auto_mix_b = torch.randn(batches, N, O, dtype=torch.float32, device=device)
            add_auto_mix = torch.randn(batches, M, O, dtype=torch.float32, device=device)
            x_man_bf16_a = x_auto_mix_a.to(torch.bfloat16)
            x_man_bf16_b = x_auto_mix_b.to(torch.bfloat16)
            add_man_bf16 = add_auto_mix.to(torch.bfloat16)
            with AutoDNNL(True), AutoMixPrecision(False):
                res_man_bf16 = torch.baddbmm(add_man_bf16, x_man_bf16_a, x_man_bf16_b,
                                             beta=beta, alpha=alpha)
                with AutoMixPrecision(True):
                    res_auto_mix = torch.baddbmm(add_auto_mix, x_auto_mix_a, x_auto_mix_b,
                                                 beta=beta, alpha=alpha)
                    self.assertEqual(res_auto_mix.dtype, torch.float)
                    self.assertTrue(ipex.core.is_bf16_dil_tensor(res_auto_mix))
                    self.assertEqual(res_auto_mix, res_man_bf16.float())
def calculate_convolution(weight_gate_list, bias_gate_list, weight_list,
                          bias_list, index, d_kind, adj_matrix, h):
    if self.gate_flag:
        z = torch.baddbmm(
            bias_gate_list[index].expand([current_batch_size, -1, -1]),
            batch1=weight_gate_list[index].expand([current_batch_size, -1, -1]),
            batch2=h)
        gate = self.gate_activity(z)
    else:
        gate = 1
    relation = torch.baddbmm(
        bias_list[index].expand([current_batch_size, -1, -1]),
        batch1=weight_list[index].expand([current_batch_size, -1, -1]),
        batch2=h)
    if hasattr(self, 'dropout'):
        relation = self.dropout(relation)
    relation = self.norm_item * gate * relation
    relation = torch.bmm(relation, adj_matrix[:, d_kind])
    return relation
def SkipConnectLSTMCell(input, hidden, hidden_skip, w_ih, w_hh,
                        b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(4, *input.size()) if noise_in is None \
        else input.unsqueeze(0) * noise_in

    hx, cx = hidden
    hx = torch.cat([hx, hidden_skip], dim=1)
    hx = hx.expand(4, *hx.size()) if noise_hidden is None \
        else hx.unsqueeze(0) * noise_hidden

    gates = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih) \
        + torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    ingate, forgetgate, cellgate, outgate = gates

    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)
    return hy, cy
def reordered(self, q, k, v, mask, head_m, **kw):
    cfg = self.cfg
    yo = self.get_y_opts(**kw)
    alpha = 1.0
    if cfg.scale:
        alpha /= float(v.size(-1)) ** 0.5
    if cfg.scale_by_inv:
        alpha /= float(self.lay_i + 1)
    b, n, n_q, d = q.size()
    _, _, n_k, _ = k.size()
    a = torch.empty(b * n, n_q, n_k, dtype=torch.float32, device=q.device)
    if is_amp:
        with autocast(enabled=False):
            q, k = q.reshape(-1, n_q, d), k.transpose(-1, -2).reshape(-1, d, n_k)
            a = torch.baddbmm(a, q.float(), k.float(), beta=0, alpha=alpha)
            a = a.reshape(b, n, n_q, n_k)
    else:
        q, k = q.reshape(-1, n_q, d), k.transpose(-1, -2).reshape(-1, d, n_k)
        a = torch.baddbmm(a, q.float(), k.float(), beta=0, alpha=alpha)
        a = a.reshape(b, n, n_q, n_k)
    if not self.is_cross:
        m = self.bias[:, :, n_k - n_q:n_k, :n_k].bool()
        a = torch.where(m, a, self.bias_m.to(a.dtype))
    if mask is not None:
        a = a + mask
    a = self.drop_attn(F.softmax(a, dim=-1).type(v.dtype))
    if head_m is not None:
        a = a * head_m
    y = (torch.matmul(a, v),)
    if yo.attn:
        y += (a,)
    return y
def bpdist(feature, data_format='NCW'):
    """Compute pairwise (square) distances of features.

    Based on $(x-y)^2 = x^2 + y^2 - 2xy$.

    Args:
        feature (torch.Tensor): (batch_size, channels, num_inst)
        data_format (str): the format of features. [NCW/NWC]

    Returns:
        distance (torch.Tensor): (batch_size, num_inst, num_inst)

    Notes:
        This method returns square distances, and is optimized for lower memory
        and faster speed. Summing the squares is more efficient than gathering
        the diagonal of the inner product. The result is slightly less accurate
        than directly computing $(x-y)^2$.

    """
    assert data_format in ('NCW', 'NWC')
    if data_format == 'NCW':
        square_sum = torch.sum(feature ** 2, 1, keepdim=True)
        square_sum = square_sum.transpose(1, 2) + square_sum
        distance = torch.baddbmm(square_sum, feature.transpose(1, 2), feature, alpha=-2.0)
    else:
        square_sum = torch.sum(feature ** 2, 2, keepdim=True)
        square_sum = square_sum.transpose(1, 2) + square_sum
        distance = torch.baddbmm(square_sum, feature, feature.transpose(1, 2), alpha=-2.0)
    return distance
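# Sketch (not from the original source): a quick numerical check that the
# baddbmm-based bpdist above matches the naive |x - y|^2 expansion. The helper
# name and tensor sizes below are illustrative assumptions.
def _check_bpdist():
    import torch
    feature = torch.randn(2, 16, 10)  # (batch_size, channels, num_inst), NCW
    fast = bpdist(feature)
    # naive: distance[b, i, j] = sum_c (f[b, c, i] - f[b, c, j])^2
    diff = feature.unsqueeze(3) - feature.unsqueeze(2)  # (B, C, N, N)
    naive = (diff ** 2).sum(dim=1)
    assert torch.allclose(fast, naive, atol=1e-4)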
def SkipConnectGRUCell(input, hidden, hidden_skip, w_ih, w_hh,
                       b_ih=None, b_hh=None, noise_in=None, noise_hidden=None):
    input = input.expand(3, *input.size()) if noise_in is None \
        else input.unsqueeze(0) * noise_in

    hx = torch.cat([hidden, hidden_skip], dim=1)
    hx = hx.expand(3, *hx.size()) if noise_hidden is None \
        else hx.unsqueeze(0) * noise_hidden

    gi = torch.baddbmm(b_ih.unsqueeze(1), input, w_ih)
    gh = torch.baddbmm(b_hh.unsqueeze(1), hx, w_hh)
    i_r, i_i, i_n = gi
    h_r, h_i, h_n = gh

    resetgate = F.sigmoid(i_r + h_r)
    inputgate = F.sigmoid(i_i + h_i)
    newgate = F.tanh(i_n + resetgate * h_n)
    hy = newgate + inputgate * (hidden - newgate)
    return hy
def bpdist2(feature1, feature2, data_format='NCW'):
    """Compute pairwise (square) distances of features.

    Args:
        feature1 (torch.Tensor): (batch_size, channels, num_inst1)
        feature2 (torch.Tensor): (batch_size, channels, num_inst2)
        data_format (str): the format of features. [NCW/NWC]

    Returns:
        distance (torch.Tensor): (batch_size, num_inst1, num_inst2)

    """
    assert data_format in ('NCW', 'NWC')
    if data_format == 'NCW':
        square_sum1 = torch.sum(feature1 ** 2, 1, keepdim=True)
        square_sum2 = torch.sum(feature2 ** 2, 1, keepdim=True)
        square_sum = square_sum1.transpose(1, 2) + square_sum2
        distance = torch.baddbmm(square_sum, feature1.transpose(1, 2), feature2, alpha=-2.0)
    else:
        square_sum1 = torch.sum(feature1 ** 2, 2, keepdim=True)
        square_sum2 = torch.sum(feature2 ** 2, 2, keepdim=True)
        square_sum = square_sum1 + square_sum2.transpose(1, 2)
        distance = torch.baddbmm(square_sum, feature1, feature2.transpose(1, 2), alpha=-2.0)
    return distance
def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None):
    # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
    bsz, num_heads, q_seq_len, dk = query.size()
    _, _, k_seq_len, _ = key.size()

    # Preallocate attn_weights for `baddbmm`
    attn_weights = torch.empty(bsz * num_heads, q_seq_len, k_seq_len,
                               dtype=torch.float32, device=query.device)

    # Compute Scale Factor
    scale_factor = 1.0
    if self.scale_attn_weights:
        scale_factor /= float(value.size(-1)) ** 0.5
    if self.scale_attn_by_inverse_layer_idx:
        scale_factor /= float(self.layer_idx + 1)

    # Upcast (turn off autocast) and reorder (Scale K by 1 / root(dk))
    if is_amp_available:
        with autocast(enabled=False):
            q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
            attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
            attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)
    else:
        q, k = query.reshape(-1, q_seq_len, dk), key.transpose(-1, -2).reshape(-1, dk, k_seq_len)
        attn_weights = torch.baddbmm(attn_weights, q.float(), k.float(), beta=0, alpha=scale_factor)
        attn_weights = attn_weights.reshape(bsz, num_heads, q_seq_len, k_seq_len)

    if not self.is_cross_attention:
        # if only "normal" attention layer implements causal mask
        query_length, key_length = query.size(-2), key.size(-2)
        causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool()
        mask_value = torch.finfo(attn_weights.dtype).min
        # Need to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Need to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`
        mask_value = torch.tensor(mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

    if attention_mask is not None:
        # Apply the attention mask
        attn_weights = attn_weights + attention_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    # Downcast (if necessary) back to V's dtype (if in mixed-precision) -- No-Op otherwise
    if attn_weights.dtype != torch.float32:
        raise RuntimeError("Error with upcasting, attn_weights does not have dtype torch.float32")
    attn_weights = attn_weights.type(value.dtype)
    attn_weights = self.attn_dropout(attn_weights)

    # Mask heads if we want to
    if head_mask is not None:
        attn_weights = attn_weights * head_mask

    attn_output = torch.matmul(attn_weights, value)
    return attn_output, attn_weights
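# Sketch (assumption, not part of the upstream code): with beta=0 the baddbmm
# call above reduces to an alpha-scaled batched matmul; preallocating the
# output tensor only pins down dtype and device. Names and sizes below are
# illustrative.
def _check_scaled_bmm():
    import torch
    q = torch.randn(6, 5, 4)
    k = torch.randn(6, 4, 7)
    out = torch.empty(6, 5, 7)  # contents are irrelevant since beta=0
    res = torch.baddbmm(out, q, k, beta=0, alpha=0.5)
    assert torch.allclose(res, 0.5 * torch.bmm(q, k), atol=1e-6)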
def forward(self, z):
    res = torch.baddbmm(self.bias.repeat(z.shape[0], 1).unsqueeze(2),
                        self.w.unsqueeze(0).repeat(z.shape[0], 1, 1).transpose(1, 2),
                        z.unsqueeze(2))  # + b.repeat(z.shape[0]).
    res = torch.baddbmm(z.unsqueeze(2),
                        self.u.unsqueeze(0).repeat(z.shape[0], 1, 1),
                        torch.tanh(res))
    res = res.squeeze()
    self.current_det = self.logdet_jacobian(z)
    return res
def logdet_jacobian(self, z):
    det = torch.baddbmm(self.bias.repeat(z.shape[0], 1).unsqueeze(2),
                        self.w.unsqueeze(0).repeat(z.shape[0], 1, 1).transpose(1, 2),
                        z.unsqueeze(2))  # + b.repeat(z.shape[0]).
    det = torch.bmm(1 - torch.pow(torch.tanh(det), 2),
                    self.w.unsqueeze(0).repeat(z.shape[0], 1, 1).transpose(1, 2))
    det = torch.baddbmm(torch.ones(z.shape[0], 1, 1), det,
                        self.u.unsqueeze(0).repeat(z.shape[0], 1, 1))
    det = det.squeeze().abs()
    return det
def orthogonal(points, calibrations, transforms=None):
    '''
    Compute the orthogonal projections of 3D points into the image plane by given projection matrix
    :param points: [B, 3, N] Tensor of 3D points
    :param calibrations: [B, 4, 4] Tensor of projection matrix
    :param transforms: [B, 2, 3] Tensor of image transform matrix
    :return: xyz: [B, 3, N] Tensor of xyz coordinates in the image plane
    '''
    rot = calibrations[:, :3, :3]
    trans = calibrations[:, :3, 3:4]
    pts = torch.baddbmm(trans, rot, points)  # [B, 3, N]
    if transforms is not None:
        scale = transforms[:2, :2]
        shift = transforms[:2, 2:3]
        pts[:, :2, :] = torch.baddbmm(shift, scale, pts[:, :2, :])
    return pts
def gram_matrix(input_tensor, device=torch.device("cuda")):
    """
    Compute Gram matrix

    :param input_tensor: input tensor with shape (batch_size, nbr_channels, height, width)
    :return: Gram matrix of y
    """
    (b, ch, h, w) = input_tensor.size()
    features = input_tensor.view(b, ch, w * h)
    features_t = features.transpose(1, 2)

    # more efficient and formal way to avoid underflow for mixed precision training
    input = torch.zeros(b, ch, ch).type(features.type()).to(device)
    gram = torch.baddbmm(input, features, features_t,
                         beta=0, alpha=1. / (ch * h * w), out=None)

    # naive way to avoid underflow for mixed precision training
    # features = features / (ch * h)
    # gram = features.bmm(features_t) / w

    # for fp32 training, it is also safe to use the following:
    # gram = features.bmm(features_t) / (ch * h * w)
    return gram
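# Sketch (assumption): for fp32 inputs the alpha-scaled baddbmm above agrees
# with the naive normalized Gram product; folding 1/(ch*h*w) into alpha is what
# keeps the accumulation from underflowing in mixed precision. Sizes below are
# arbitrary.
def _check_gram_matrix():
    import torch
    x = torch.randn(2, 3, 8, 8)
    g = gram_matrix(x, device=torch.device("cpu"))
    b, ch, h, w = x.size()
    f = x.view(b, ch, h * w)
    naive = f.bmm(f.transpose(1, 2)) / (ch * h * w)
    assert torch.allclose(g, naive, atol=1e-6)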
def backward(self, grad_output):
    grad_input1 = torch.zeros(self.input1.size())
    grad_input1 = torch.baddbmm(
        grad_input1,
        torch.transpose(grad_output.view(-1, self.height * self.width * self.depth, 3), 1, 2),
        self.batchgrid.view(-1, self.height * self.width * self.depth, 4))
    # print(grad_input1)
    return grad_input1
def batched_all_pairs_squared_l2_dist(
    a: FloatTensorType,
    b: FloatTensorType,
) -> FloatTensorType:
    """For each batch, return the squared L2 distance between each pair of vectors

    Let A and B be tensors of shape NxM_AxD and NxM_BxD, each containing N*M_A
    and N*M_B vectors of dimension D grouped in N batches of size M_A and M_B.
    For each batch, for each vector of A and each vector of B, return the sum
    of the squares of the differences of their components.

    """
    num_chunks, num_a, dim = match_shape(a, -1, -1, -1)
    num_b = match_shape(b, num_chunks, -1, dim)
    a_squared = a.norm(dim=-1).pow(2)
    b_squared = b.norm(dim=-1).pow(2)
    # Calculate res_i,k = sum_j((a_i,j - b_k,j)^2) for each i and k as
    # sum_j(a_i,j^2) - 2 sum_j(a_i,j b_k,j) + sum_j(b_k,j^2), by using a matrix
    # multiplication for the ab part, adding the b^2 as part of the baddbmm call
    # and the a^2 afterwards.
    res = torch.baddbmm(
        b_squared.unsqueeze(-2), a, b.transpose(-2, -1), alpha=-2
    ).add_(a_squared.unsqueeze(-1))
    match_shape(res, num_chunks, num_a, num_b)
    return res
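# Sketch (assumption): cross-checking the baddbmm formulation against
# torch.cdist; shapes follow the NxM_AxD / NxM_BxD convention from the
# docstring above.
def _check_batched_l2():
    import torch
    a = torch.randn(3, 5, 4)
    b = torch.randn(3, 7, 4)
    res = batched_all_pairs_squared_l2_dist(a, b)
    assert torch.allclose(res, torch.cdist(a, b) ** 2, atol=1e-4)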
def forward(self, left_in, right_in):
    left, right = bundle(left_in), bundle(right_in)
    lstm_gates = self.left(left.h)
    lstm_gates += self.right(right.h)

    # Core composition
    # Retrieve hidden state from left_in and feed it into weight matrix
    hidden_dim = self.dim * self.dim
    h = left.h.contiguous().view(-1, self.dim, self.dim)
    cell_inp = torch.matmul(self.weight, h)
    cell_inp = F.tanh(torch.add(cell_inp, self.b1))

    # Retrieve hidden state from right_in
    h = right.h.contiguous().view(-1, self.dim, self.dim)
    cell_inp = F.tanh(torch.baddbmm(self.b2, cell_inp, h))
    cell_inp = cell_inp.view(-1, hidden_dim)

    out = unbundle(treelstmtensor(left.c, right.c, lstm_gates, cell_inp,
                                  training=self.training))
    return out
def set_full_solution_batched(self):
    # Combines essential boundary conditions and solution of equation system
    self.full_solution_torch = torch.baddbmm(
        self.mesh.essential_solution_vector_torch.unsqueeze(0).expand(
            self.solution_torch.shape[0], -1, -1),
        self.mesh.scatter_matrix_torch.unsqueeze(0).expand(
            self.solution_torch.shape[0], -1, -1),
        self.solution_torch)
def prediction_layer(self, S_i, u_i, instg_S_i, q_j, qb_j, for_pred=False):
    if for_pred:
        q_j = q_j.squeeze()
        qb_j = qb_j.squeeze()
        # long-term
        res = u_i.mm(q_j.t()) + qb_j
        # short-term
        res += instg_S_i.mm(q_j.t())
        # item-item
        rel_score = torch.matmul(S_i, q_j.t().unsqueeze(0))
        rel_score = torch.sum(rel_score, dim=1)
        res += rel_score
    else:
        # long-term
        res = torch.baddbmm(qb_j, q_j, u_i.unsqueeze(2)).squeeze()
        # short-term
        res += torch.bmm(instg_S_i.unsqueeze(1), q_j.permute(0, 2, 1)).squeeze()
        # item-item
        rel_score = S_i.bmm(q_j.permute(0, 2, 1))
        rel_score = torch.sum(rel_score, dim=1)
        res += rel_score
    return res
def batch_linear(x, W, b=None):
    """Computes y_i = x_i W_i + b_i where i is each observation index.

    This is similar to `torch.nn.functional.linear`, but a version that
    supports a different W for each observation.

    x: has shape [obs, in_dims]
    W: has shape [obs, out_dims, in_dims]
    b: has shape [out_dims]
    """
    if x.size()[1] != W.size()[-1]:
        raise ValueError(
            f'the in_dim of x ({x.size()[1]}) does not match in_dim of W ({W.size()[-1]})')
    if x.size()[0] != W.size()[0]:
        raise ValueError(
            f'the obs of x ({x.size()[0]}) does not match obs of W ({W.size()[0]})')

    obs = x.size()[0]
    in_dims = x.size()[1]
    out_dims = W.size()[1]
    x = x.view(obs, 1, in_dims)
    W = W.transpose(-2, -1)
    if b is None:
        return torch.bmm(x, W).view(obs, out_dims)
    else:
        b = b.view(1, 1, out_dims)
        return torch.baddbmm(b, x, W, beta=1, alpha=1).view(obs, out_dims)
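# Sketch (assumption): batch_linear against an explicit per-observation loop.
# The helper name and sizes are arbitrary illustrative choices.
def _check_batch_linear():
    import torch
    x = torch.randn(4, 3)     # [obs, in_dims]
    W = torch.randn(4, 5, 3)  # [obs, out_dims, in_dims]
    b = torch.randn(5)        # [out_dims]
    y = batch_linear(x, W, b)
    expected = torch.stack([x[i] @ W[i].t() + b for i in range(4)])
    assert torch.allclose(y, expected, atol=1e-6)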
def get_score(self, target_idx, noise_idx, input):
    """
    Shape:
        - target_batch :math:`(N, E, 1+N_r)` where `N = length, E = embedding size, N_r = noise ratio`
    """
    # flatten the following matrix
    input = input.contiguous().view(-1, input.size(-1))
    original_size = target_idx.size()  # the size will be used to pack the output of indexlinear
    target_idx = target_idx.view(-1)
    noise_idx = noise_idx.view(-1, noise_idx.size(-1))
    indices = torch.cat([target_idx.unsqueeze(-1), noise_idx], dim=-1)

    # the pytorch's [] operator can't BP correctly with redundant indices
    # before version 0.2.0
    input = input.unsqueeze(1)
    target_batch = self.weight.index_select(0, indices.view(-1)).view(
        *indices.size(), -1).transpose(1, 2)
    bias = self.bias.index_select(0, indices.view(-1)).view_as(indices).unsqueeze(1)
    out = torch.baddbmm(bias, input, target_batch).view(*original_size, -1)
    target_score, noise_score = out[:, :, 0], out[:, :, 1:]
    return target_score, noise_score
def gram_matrix(input_tensor):
    """
    Compute Gram matrix

    :param input_tensor: input tensor with shape (batch_size, nbr_channels, height, width)
    :return: Gram matrix of y

    Ripped from: https://github.com/NVIDIA/partialconv/blob/master/models/loss.py#L17-L40
    """
    (b, ch, h, w) = input_tensor.size()
    features = input_tensor.view(b, ch, w * h)
    features_t = features.transpose(1, 2)

    # more efficient and formal way to avoid underflow for mixed precision training
    input = torch.zeros(b, ch, ch).type(features.type())
    gram = torch.baddbmm(input, features, features_t,
                         beta=0, alpha=1. / (ch * h * w), out=None)

    # naive way to avoid underflow for mixed precision training
    # features = features / (ch * h)
    # gram = features.bmm(features_t) / w

    # for fp32 training, it is also safe to use the following:
    # gram = features.bmm(features_t) / (ch * h * w)
    return gram
def reproject_points(points_cam_ref, extrinsics_ref, extrinsics_tgt):
    """Reproject points in reference camera coordinate to target camera coordinate

    Args:
        points_cam_ref (B, 3, H, W): points in reference camera coordinate.
        extrinsics_ref (B, 3, 4): [R, t] of reference camera.
        extrinsics_tgt (B, 3, 4): [R, t] of target camera.

    Returns:
        points_cam_tgt (B, 3, H, W): points in target camera coordinate.
    """
    B, p_dim, H, W = points_cam_ref.shape
    assert p_dim == 3, "dimension of point {} != 3".format(p_dim)

    # t + R * p where t of (B, 3, 1), R of (B, 3, 3) and p of (B, 3, H*W)
    R_ref = extrinsics_ref[..., :p_dim]
    t_ref = extrinsics_ref[..., -1:]
    points_world = torch.baddbmm(t_ref, R_ref, points_cam_ref.view(B, p_dim, -1))

    # Reproject to target:
    # R'^T * (p - t') where t' of (B, 3, 1), R' of (B, 3, 3) and p of (B, 3, H*W)
    R_tgt = extrinsics_tgt[..., :p_dim]
    t_tgt = extrinsics_tgt[..., -1:]
    points_cam_tgt = torch.bmm(R_tgt.transpose(1, 2), points_world - t_tgt)
    return points_cam_tgt.view(B, p_dim, H, W)
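# Sketch (assumption): a round-trip sanity check -- with identical reference
# and target extrinsics, reproject_points must return its input unchanged.
def _check_reproject_identity():
    import torch
    B, H, W = 2, 4, 5
    pts = torch.randn(B, 3, H, W)
    extr = torch.cat([torch.eye(3).expand(B, 3, 3),
                      torch.zeros(B, 3, 1)], dim=-1)  # (B, 3, 4), [I | 0]
    out = reproject_points(pts, extr, extr)
    assert torch.allclose(out, pts, atol=1e-5)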
def apply_classification_weights(self, x_transformed, weights):
    """
    Computing logits for classification

    Args:
        x_transformed : feature of query_x [batch_size, n_way, k_shot, feature_dim]
        weights : prototype_weight [batch_size, n_way, dim]
    Returns:
        logits : [batch_size, n_way, k_shot, n_way]
    """
    # Reshape
    task_size, n_way, k_shot, feature_dim = x_transformed.size()
    x_transformed = x_transformed.view(task_size, -1, feature_dim)

    # Normalizing
    x_transformed = F.normalize(x_transformed, p=2, dim=x_transformed.dim() - 1, eps=1e-12)
    weights = F.normalize(weights, p=2, dim=weights.dim() - 1, eps=1e-12)

    # logits [task_size, n_way * k_shot, n_way]
    logits = self.scale_cls * torch.baddbmm(self.bias.view(1, 1, 1),
                                            x_transformed,
                                            weights.transpose(1, 2),
                                            beta=1.0, alpha=1.0)

    # Reshape : logits [batch_size, n_way, k_shot, n_way (num_classes)]
    logits = logits.view(task_size, n_way, k_shot, -1)
    return logits
def forward(self, input: torch.Tensor):
    size = input.size()
    assert input.dim() == self.dim and size[1] == self.num_features
    x = input.view(size[0], self.num_groups, size[1] // self.num_groups, *size[2:])
    x = x.view(size[0], self.num_groups, -1)
    IG, d, m = x.size()
    mean = x.mean(-1, keepdim=True)
    x_mean = x - mean
    P = [torch.Tensor([]) for _ in range(self.T + 1)]
    sigma = x_mean.matmul(x_mean.transpose(1, 2)) / m
    P[0] = torch.eye(d).to(x).expand(sigma.shape)
    M_zero = sigma.clone().fill_(0)
    trace_inv = torch.addcmul(M_zero, sigma, P[0]).sum((1, 2), keepdim=True).reciprocal_()
    sigma_N = torch.addcmul(M_zero, sigma, trace_inv)
    for k in range(self.T):
        # Newton-Schulz iteration: P_{k+1} = 1.5 * P_k - 0.5 * P_k^3 * sigma_N
        P[k + 1] = torch.baddbmm(P[k], self.matrix_power3(P[k]), sigma_N,
                                 beta=1.5, alpha=-0.5)
    wm = torch.addcmul(M_zero, P[self.T], trace_inv.sqrt())
    y = wm.matmul(x_mean)
    output = y.view(size[0], self.num_groups, size[1] // self.num_groups, *size[2:])
    output = output.view_as(input)
    if self.affine:
        output = output * self.weight + self.bias
    return output
def forward(self, mix_hidden, query):
    # TODO: tidy this up; attention could in fact also be computed directly,
    # without the memory | DONE
    BATCH_SIZE = mix_hidden.size()[0]
    # assert query.size() == (BATCH_SIZE, self.hidden_size)
    # assert mix_hidden.size()[-1] == self.hidden_size
    # mix_hidden: bs, max_len, fre, hidden_size; query: bs, hidden_size
    if self.mode == 'dot':
        # mix_hidden = mix_hidden.view(-1, 1, self.hidden_size)
        mix_shape = mix_hidden.size()
        mix_hidden = mix_hidden.view(BATCH_SIZE, -1, self.hidden_size)
        query = query.view(-1, self.hidden_size, 1)
        # print '\n\n', mix_hidden.requires_grad, query.requires_grad, '\n\n'
        dot = torch.baddbmm(Variable(torch.zeros(1, 1).cuda()), mix_hidden, query)
        energy = dot.view(BATCH_SIZE, mix_shape[1], mix_shape[2])
        mask = F.sigmoid(energy)
        return mask
    elif self.mode == 'align':
        # mix_hidden = Variable(mix_hidden)
        # query = Variable(query)
        mix_shape = mix_hidden.size()
        mix_hidden = mix_hidden.view(-1, self.hidden_size)
        mix_hidden = self.Linear_1(mix_hidden).view(
            BATCH_SIZE, -1, self.align_hidden_size)
        query = self.Linear_2(query).view(-1, 1, self.align_hidden_size)  # bs, 1, hidden
        sum = F.tanh(mix_hidden + query)
        # TODO: pick up from here
        energy = self.Linear_3(sum.view(-1, self.align_hidden_size)).view(
            BATCH_SIZE, mix_shape[1], mix_shape[2])
        mask = F.sigmoid(energy)
        return mask
def forward(self, x, lengths, h_prev, target_head, compute_softmax=False):
    # x should be a numpy array of n_seq x n_batch dimensions
    b_sz = x.size(1)
    n_steps = x.size(0)
    x = Variable(x).cuda()
    emb = self.enc_drop(self.encoder(x))
    packed = pack_padded_sequence(emb, lengths)
    rnn_out, hidden = self._my_recurrent_layer(packed, h_prev)
    rnn_out_unp = pad_packed_sequence(rnn_out)
    rnn_out = self.dec_drop(rnn_out_unp[0])

    # implement the multi-headed RNN.
    W = self.decoder_W[target_head.cuda()]
    # reshape and expand b to size (batch * n_steps * vocab_size)
    b = self.decoder_b[target_head.cuda()].view(b_sz, -1, self.output_size)
    b = b.expand(b_sz, x.size(0), self.output_size)

    # output is size seq * batch_size * vocab
    dec_out = torch.baddbmm(b, rnn_out.transpose(0, 1), W).transpose(0, 1)

    if compute_softmax:
        prob_out = self.softmax(dec_out.view(-1, self.output_size)).view(
            n_steps, b_sz, self.output_size)
    else:
        prob_out = dec_out
    return prob_out, hidden
def forward_eval(self, x, h_prev, compute_softmax=True):
    # x should be a numpy array of n_seq x n_batch dimensions
    # In this case batch will be a single sequence.
    n_auth = self.num_output_layers
    n_steps = x.size(0)
    x = Variable(x, volatile=True).cuda()

    # No Dropout needed
    emb = self.encoder(x)
    # No need for any packing here
    packed = emb
    rnn_out, hidden = self._my_recurrent_layer(packed, h_prev)

    # implement the multi-headed RNN.
    rnn_out = rnn_out.expand(n_steps, n_auth, self.hidden_size)
    W = self.decoder_W
    # reshape and expand b to size (n_auth * n_steps * vocab_size)
    b = self.decoder_b.view(n_auth, -1, self.output_size).expand(
        n_auth, n_steps, self.output_size)

    # output is size seq * batch_size * vocab
    dec_out = torch.baddbmm(b, rnn_out.transpose(0, 1), W).transpose(0, 1)

    if compute_softmax:
        prob_out = self.softmax(dec_out.contiguous().view(-1, self.output_size)).view(
            n_steps, n_auth, self.output_size)
    else:
        prob_out = dec_out
    return prob_out, hidden
def forward(ctx, add_batch, batch1, batch2, alpha=1, beta=1, inplace=False):
    ctx.alpha = alpha
    ctx.beta = beta
    ctx.add_batch_size = add_batch.size()
    ctx.save_for_backward(batch1, batch2)
    output = _get_output(ctx, add_batch, inplace=inplace)
    # here `alpha` scales add_batch and `beta` scales batch1 @ batch2, i.e.
    # the roles are swapped relative to torch.baddbmm's beta/alpha keywords
    return torch.baddbmm(add_batch, batch1, batch2,
                         beta=alpha, alpha=beta, out=output)
def backward(self, grad_output):
    grad_input1 = self.input1.new(self.input1.size()).zero_()
    # if grad_output.is_cuda:
    #     self.batchgrid = self.batchgrid.cuda()
    #     grad_input1 = grad_input1.cuda()
    grad_input1 = torch.baddbmm(
        grad_input1,
        torch.transpose(grad_output.view(-1, self.height * self.width, 2), 1, 2),
        self.batchgrid.view(-1, self.height * self.width, 3))
    return grad_input1