def bilinear_naive(input1, input2, weight, bias=None, conjugate=True):
    r"""Applies a complex bilinear transformation to the incoming complex data:
    :math:`y = x^(T/H) W z + b`.

    Writing ``x = input1 = a + i b``, ``z = input2 = u + i v`` and
    ``W = weight = U + i V``, the complex form is assembled from four real
    bilinear forms. ``conjugate=True`` uses the conjugate transpose ``x^H``
    of ``input1``; ``conjugate=False`` uses the plain transpose ``x^T``.

    Returns a ``Cplx`` with ``weight.shape[0]`` output features; ``bias``
    (if not None) is added to it.
    """
    n_out = weight.shape[0]
    # A single real weight stacking [Re W; Im W] along the output axis, so one
    # real F.bilinear call yields both x^T (Re W) y and x^T (Im W) y.
    re_im_weight = torch.cat([weight.real, weight.imag], dim=0)

    def _real_form(lhs, rhs):
        # Real bilinear form against the stacked [Re W; Im W] weight.
        return F.bilinear(lhs, rhs, re_im_weight, bias=None)

    x_re, x_im = input1.real, input1.imag
    z_re, z_im = input2.real, input2.imag
    form_rr = _real_form(x_re, z_re)
    form_ri = _real_form(x_re, z_im)
    form_ir = _real_form(x_im, z_re)
    form_ii = _real_form(x_im, z_im)

    # x^(T/H) W z = P + i Q with P, Q still "complex in W".
    if conjugate:
        p_part = form_rr + form_ii
        q_part = form_ri - form_ir
    else:
        p_part = form_rr - form_ii
        q_part = form_ri + form_ir

    p_re, p_im = p_part[..., :n_out], p_part[..., n_out:]
    q_re, q_im = q_part[..., :n_out], q_part[..., n_out:]
    # (P_re + i P_im) + i (Q_re + i Q_im)
    output = Cplx(p_re - q_im, p_im + q_re)
    if bias is not None:
        output += bias
    return output
def forward(self, input_left, input_right):
    """Biaffine transform of two inputs that share their leading dims.

    Computes ``F.bilinear(x_l, x_r, U, bias) + x_l @ weight_left^T
    + x_r @ weight_right^T`` independently for every leading index.

    Args:
        input_left: Tensor
            the left input tensor with shape = [batch1, batch2, ..., left_features]
        input_right: Tensor
            the right input tensor with shape = [batch1, batch2, ..., right_features]

    Returns:
        Tensor with shape = [batch1, batch2, ..., out_features]

    Raises:
        ValueError: if the leading (batch) dims of the two inputs differ.
    """
    batch_size = input_left.size()[:-1]
    # Both inputs are flattened to 2-D below; differing leading dims with an
    # equal element count would otherwise silently pair up the wrong rows.
    if input_right.size()[:-1] != batch_size:
        raise ValueError(
            "batch size of left and right inputs mis-match: (%s, %s)"
            % (batch_size, input_right.size()[:-1]))
    batch = int(np.prod(batch_size))
    # convert left and right input to matrices [batch, left_features], [batch, right_features]
    # (reshape rather than view so non-contiguous inputs, e.g. transposed ones, work)
    input_left = input_left.reshape(batch, self.left_features)
    input_right = input_right.reshape(batch, self.right_features)
    # output [batch, out_features]
    output = F.bilinear(input_left, input_right, self.U, self.bias)
    output = (output
              + F.linear(input_left, self.weight_left, None)
              + F.linear(input_right, self.weight_right, None))
    # convert back to [batch1, batch2, ..., out_features]
    return output.view(batch_size + (self.out_features,))
def forward(self, input_left, input_right):
    """Biaffine transform ``y = x_l^T U x_r + W_l x_l + W_r x_r + b``.

    Args:
        input_left: Tensor
            the left input tensor with shape = [batch1, batch2, ..., left_features]
        input_right: Tensor
            the right input tensor with shape = [batch1, batch2, ..., right_features]

    Returns:
        Tensor with shape = [batch1, batch2, ..., out_features]

    Raises:
        ValueError: if the leading (batch) dims of the two inputs differ.
    """
    # e.g. left_size torch.Size([16, 24, 128])
    left_size = input_left.size()
    right_size = input_right.size()
    # An explicit exception rather than ``assert``: asserts vanish under -O,
    # and the flatten below would then silently misalign rows.
    if left_size[:-1] != right_size[:-1]:
        raise ValueError(
            "batch size of left and right inputs mis-match: (%s, %s)"
            % (left_size[:-1], right_size[:-1]))
    batch_size = int(np.prod(left_size[:-1]))  # e.g. 384 = 16 * 24
    # convert left and right input to matrices [batch_size, left_features], [batch_size, right_features]
    # (reshape rather than view so non-contiguous inputs, e.g. transposed ones, work)
    input_left = input_left.reshape(batch_size, self.left_features)
    input_right = input_right.reshape(batch_size, self.right_features)
    # output [batch_size, out_features]
    output = F.bilinear(input_left, input_right, self.U, self.bias)
    output = (output
              + F.linear(input_left, self.W_l, None)
              + F.linear(input_right, self.W_r, None))
    # convert back to [batch1, batch2, ..., out_features]
    return output.view(left_size[:-1] + (self.out_features,))
def forward(self, input_left, input_right):
    """Biaffine transform ``y = x_l^T U x_r + W_l x_l + W_r x_r + b``.

    Args:
        input_left: Tensor
            the left input tensor with shape = [batch1, batch2, ..., left_features]
        input_right: Tensor
            the right input tensor with shape = [batch1, batch2, ..., right_features]

    Returns:
        Tensor with shape = [batch1, batch2, ..., out_features]

    Raises:
        ValueError: if the leading (batch) dims of the two inputs differ.
    """
    left_size = input_left.size()
    right_size = input_right.size()
    # An explicit exception rather than ``assert``: asserts vanish under -O,
    # and the flatten below would then silently misalign rows.
    if left_size[:-1] != right_size[:-1]:
        raise ValueError(
            "batch size of left and right inputs mis-match: (%s, %s)"
            % (left_size[:-1], right_size[:-1]))
    batch = int(np.prod(left_size[:-1]))
    # convert left and right input to matrices [batch, left_features], [batch, right_features]
    # (reshape rather than view so non-contiguous inputs, e.g. transposed ones, work)
    input_left = input_left.reshape(batch, self.left_features)
    input_right = input_right.reshape(batch, self.right_features)
    # output [batch, out_features]
    output = F.bilinear(input_left, input_right, self.U, self.bias)
    output = (output
              + F.linear(input_left, self.W_l, None)
              + F.linear(input_right, self.W_r, None))
    # convert back to [batch1, batch2, ..., out_features]
    return output.view(left_size[:-1] + (self.out_features,))
def forward(self, input1, input2, pgn_vector):
    """Bilinear layer whose weight and bias are generated from ``pgn_vector``.

    ``self.weight`` is flattened to ``[in_params, -1]`` and contracted with
    ``pgn_vector``; the result is viewed as the
    ``[out_features, in1_features, in2_features]`` bilinear weight. When
    ``self.bias`` is set, the bias is ``pgn_vector @ self.bias``.
    """
    flat_weight = self.weight.view(self.in_params, -1)
    generated = (pgn_vector @ flat_weight).view(
        self.out_features, self.in1_features, self.in2_features)
    generated_bias = None if self.bias is None else pgn_vector @ self.bias
    return F.bilinear(input1, input2, generated, generated_bias)
def forward(self, input1, input2):
    """Complex SGVB forward pass (local reparameterization trick).

    Outside training, returns the parent layer's output unchanged. In
    training, adds noise from ``cplx.randn_like`` scaled by the square root
    of ``s2 = bilinear(|input1|^2, |input2|^2, exp(log_sigma2))``, with
    ``s2`` clamped from below at 1e-8.
    """
    mu = super().forward(input1, input2)
    if self.training:
        # Squared moduli of the complex inputs.
        mod1_sq = input1.real * input1.real + input1.imag * input1.imag
        mod2_sq = input2.real * input2.real + input2.imag * input2.imag
        s2 = F.bilinear(mod1_sq, mod2_sq, torch.exp(self.log_sigma2), None)
        eps = cplx.randn_like(s2)
        return mu + eps * torch.sqrt(torch.clamp(s2, 1e-8))
    return mu
def forward(self, input1, input2):
    """Complex SGVB forward pass (local reparameterization trick).

    Outside training, returns the parent layer's output unchanged. In
    training, adds complex noise ``Cplx(n_re, n_im) / sqrt(2)`` built from
    two ``torch.randn_like`` draws. The noise is scaled by the square root
    of ``s2 = bilinear(|input1|^2, |input2|^2, exp(log_sigma2))``, with
    ``s2`` clamped from below at 1e-8.
    """
    mu = super().forward(input1, input2)
    if self.training:
        # Squared moduli of the complex inputs.
        mod1_sq = input1.real * input1.real + input1.imag * input1.imag
        mod2_sq = input2.real * input2.real + input2.imag * input2.imag
        s2 = F.bilinear(mod1_sq, mod2_sq, torch.exp(self.log_sigma2), None)
        # Real part drawn first, then imaginary part.
        noise_re = torch.randn_like(s2)
        noise_im = torch.randn_like(s2)
        noise = Cplx(noise_re, noise_im) / sqrt(2)
        return mu + noise * torch.sqrt(torch.clamp(s2, 1e-8))
    return mu
def forward(self, x1: Tensor, x2: Tensor):
    """Biaffine scoring of ``x1`` against ``x2`` with ``self.weight``.

    When ``self.bias_x`` / ``self.bias_y`` is set, a constant 1 is appended
    to the last axis of ``x1`` / ``x2`` first.

    With ``self.expand``, every position of ``x1`` is scored against every
    position of ``x2``, giving ``[batch_size, n_out, len_x1, len_x2]``.
    Otherwise positions are paired element-wise by ``F.bilinear``, which puts
    the output axis last: ``[..., n_out]``.
    """
    def _append_ones(t):
        return torch.cat((t, torch.ones_like(t[..., :1])), -1)

    if self.bias_x:
        x1 = _append_ones(x1)
    if self.bias_y:
        x2 = _append_ones(x2)
    if not self.expand:
        return F.bilinear(x1, x2, self.weight, None)
    return torch.einsum('bxi,oij,byj->boxy', x1, self.weight, x2)
def forward(self, input1, input2):
    r"""Forward pass of the SGVB method for a bilinear layer.

    Straightforward generalization of the local reparameterization trick:
    outside training the parent layer's output is returned as-is. In training,
    Gaussian noise scaled by the square root of
    ``s2 = bilinear(input1^2, input2^2, exp(log_sigma2))`` is added, with
    ``s2`` clamped from below at 1e-8.
    """
    mu = super().forward(input1, input2)
    if self.training:
        s2 = F.bilinear(input1 * input1, input2 * input2,
                        torch.exp(self.log_sigma2), None)
        # .normal reports a grad-fn, but weirdly does not pass grads!
        eps = torch.randn_like(s2)
        return mu + eps * torch.sqrt(torch.clamp(s2, 1e-8))
    return mu
def forward(self, input_left, input_right):
    """Biaffine transform ``y = x_l^T U x_r + W_l x_l + W_r x_r + b``.

    Args:
        input_left: tensor of shape [batch1, ..., left_features]
        input_right: tensor of shape [batch1, ..., right_features]

    Returns:
        Tensor of shape [batch1, ..., out_features].

    Raises:
        ValueError: if the leading (batch) dims of the two inputs differ.
    """
    left_size = input_left.size()
    right_size = input_right.size()
    # The inputs are flattened to 2-D below; mismatching leading dims with an
    # equal element count would otherwise be paired up silently and wrongly.
    if left_size[:-1] != right_size[:-1]:
        raise ValueError(
            "batch size of left and right inputs mis-match: (%s, %s)"
            % (left_size[:-1], right_size[:-1]))
    batch = int(np.prod(left_size[:-1]))
    # reshape rather than view so non-contiguous inputs (e.g. transposed) work.
    input_left = input_left.reshape(batch, self.left_features)
    input_right = input_right.reshape(batch, self.right_features)
    output = F.bilinear(input_left, input_right, self.U, self.bias)
    output = (output
              + F.linear(input_left, self.W_l, None)
              + F.linear(input_right, self.W_r, None))
    return output.view(left_size[:-1] + (self.out_features,))
def forward(self, tensora, tensorb):
    """
    YVec = aHVec @ W1 @ bVVec + W2 @ aHVec + W3 @ bVVec

    i.e. ``F.bilinear(a, b, U, bias) + a @ tensora_weight^T + b @ tensorb_weight^T``
    applied independently at every leading index.

    Args:
        tensora: shape = [dim1, dim2, ..., tensora_dim]
        tensorb: shape = [dim1, dim2, ..., tensorb_dim]

    Returns:
        Tensor of shape [dim1, dim2, ..., outputs_dim].

    Raises:
        ValueError: if the leading dims of ``tensora`` and ``tensorb`` differ.
    """
    dims_prev = tensora.size()[:-1]
    # Both inputs are flattened to 2-D; differing leading dims with equal
    # element counts would otherwise be paired silently and wrongly.
    if tensorb.size()[:-1] != dims_prev:
        raise ValueError(
            "leading dims of tensora and tensorb mis-match: (%s, %s)"
            % (dims_prev, tensorb.size()[:-1]))
    dummy_batch = int(np.prod(dims_prev))
    tensora = tensora.reshape(dummy_batch, self.tensora_dim)
    tensorb = tensorb.reshape(dummy_batch, self.tensorb_dim)
    cross = F.bilinear(tensora, tensorb, self.U, self.bias)
    partial_a = F.linear(tensora, self.tensora_weight)  # a A^T
    partial_b = F.linear(tensorb, self.tensorb_weight)  # b B^T
    out = cross + partial_a + partial_b  # [dummy_batch, outputs_dim]
    # Restore the leading dims (the reshape result must be returned, not
    # discarded).
    return out.reshape(dims_prev + (self.outputs_dim,))
def bilinear_cat(input1, input2, weight, bias=None, conjugate=True):
    """Complex bilinear form computed with a single real ``F.bilinear`` call.

    The complex ``weight`` of shape ``[n_out, n_in1, n_in2]`` is expanded into
    a real ``[2 * n_out, 2 * n_in1, 2 * n_in2]`` weight acting on the
    real/imaginary-concatenated inputs. The first ``n_out`` outputs are the
    real part and the last ``n_out`` the imaginary part. ``conjugate=True``
    conjugates ``input1`` (``x^H W z``); otherwise ``x^T W z``. ``bias``
    (if not None) is added to the result.

    The block layout assumes ``to_concatenated_real`` places real parts
    before imaginary parts along the last dim — confirm against that helper.
    """
    U, V = weight.real, weight.imag
    # Column blocks along the input2 axis: (Re z, Im z).
    u_negv = torch.cat([U, -V], dim=2)  # [U, -V]
    v_u = torch.cat([V, U], dim=2)      # [V,  U]
    # Row blocks along the input1 axis: (Re x, Im x).
    if conjugate:
        real_rows = torch.cat([u_negv, v_u], dim=1)
        imag_rows = torch.cat([v_u, -u_negv], dim=1)
    else:
        real_rows = torch.cat([u_negv, -v_u], dim=1)
        imag_rows = torch.cat([v_u, u_negv], dim=1)
    ww = torch.cat([real_rows, imag_rows], dim=0)

    x1 = to_concatenated_real(input1, dim=-1)
    x2 = to_concatenated_real(input2, dim=-1)
    output = from_concatenated_real(F.bilinear(x1, x2, ww, None))
    if bias is not None:
        output += bias
    return output
def forward(self, in1, in2):
    """Bilinear layer (no bias) whose weight is stored in factored form.

    The weight is rebuilt on every call as
    ``(left_singular * diag) @ right_singular``.
    """
    weight = torch.matmul(self.left_singular * self.diag, self.right_singular)
    return F.bilinear(in1, in2, weight, None)
def forward(self, input1, input2):
    """Weight-normalised bilinear layer.

    Each output slice of ``self.direction`` is divided by its Frobenius norm
    (sum of squares over dims 1 and 2). It is then multiplied by the matching
    entry of ``self.scale`` before ``F.bilinear`` is applied with
    ``self.bias``.
    """
    v = self.direction
    # One norm per output feature, broadcast back over the two input axes.
    slice_norms = v.pow(2).sum(1).sum(1).sqrt()
    unit_direction = v.div(slice_norms[:, None, None])
    weight = self.scale[:, None, None].mul(unit_direction)
    return F.bilinear(input1, input2, weight, self.bias)
def test_bilinear():
    """``utils.bilinear(l, A, r)`` must agree with ``F.bilinear(l, r, A)``."""
    # Fixed seed so a tolerance failure is reproducible.
    torch.manual_seed(0)
    A = torch.randn(3, 5, 4)
    l = torch.randn(2, 5)
    r = torch.randn(2, 4)
    # Two implementations may reduce in a different float order, so exact
    # ``==`` is fragile; compare with a tolerance.
    assert torch.allclose(utils.bilinear(l, A, r), F.bilinear(l, r, A), atol=1e-5)
def cpc_loss(self, gru_input_feats, gru_output_feats, feats_len):
    """Contrastive (CPC-style) loss between ``z`` features and contexts ``c``.

    For every sequence ``b`` and every step ``k`` in
    ``range(self.k_step, self.k, self.k_step)``, each feature ``z_{t+k}``
    (from ``gru_input_feats``) is scored against context ``c_t`` (from
    ``gru_output_feats``) with ``self.W[k]``. For every positive,
    ``noise = min(self.N, seq_len - 1)`` negative ``z`` positions are drawn at
    random from other time steps strictly within ``self.n_around`` of the
    positive. If that window is too small, a ±(noise + 1) window is used
    instead. Each positive score is normalised by a ``logsumexp`` over all
    positive and negative scores of that sequence/step.

    Args:
        gru_input_feats: tensor viewed as ``(B, T, -1)``; the ``z`` features.
        gru_output_feats: tensor viewed as ``(B, T, -1)``; the context features.
        feats_len: 1-D tensor with one valid length per batch element.

    Returns:
        ``(tot_loss, details)``.
        - ``tot_loss``: with ``self.reduction == "sum"``, the sum over ``k`` of
          per-step mean losses; otherwise the total loss divided by the total
          number of positives.
        - ``details``: always contains ``"loss"``.
        - ``details`` also contains ``"cer_k<k+1>"`` when
          ``self.cpc_compute_kcer`` is set: the percentage of positives that
          did not get the top score among themselves and their own negatives.
        - ``details`` also contains ``"loss_k<k+1>"`` (per-step mean loss)
          when ``self.loss_details`` is set.
    """
    zt_feats = gru_input_feats.contiguous().view(
        gru_input_feats.size(0), gru_input_feats.size(1), -1)
    ct = gru_output_feats.contiguous().view(
        gru_output_feats.size(0), gru_output_feats.size(1), -1)
    zt_length = feats_len
    lossK = {}  # key=k, value=(tot_loss_k, nb_example_k)
    nbErrK = {}  # key=k, value=(tot_err_k, nb_example_k)
    # change from BxTx(FxC) to TxBxF
    zt_feats = zt_feats.permute(1, 0, 2)
    ct = ct.permute(1, 0, 2)
    for b in range(zt_length.size(0)):
        seq_len = zt_length[b].item()
        # Row 0 of matK indexes the contexts c_t; row k indexes z_{t+k}:
        #   matK[0]: 0 1 2 3 4 5 6 ...
        #   matK[1]: 1 2 3 4 5 6 7 ...
        #   matK[2]: 2 3 4 5 6 7 8 ...
        matK = np.arange(self.k + 1)[:, np.newaxis] + np.arange(0, seq_len)
        noise = min(self.N, seq_len - 1)
        # Context index for each negative; if noise = 3: (0 0 0 1 1 1 2 2 2 ...)
        noiseC_ind = np.arange(seq_len * noise) // noise
        for k in range(self.k_step, self.k, self.k_step):
            # Valid target positions, e.g. for k=1: (1 2 3 ... seq_len-1)
            z_ind = matK[k][matK[k] < seq_len]
            in_feats_z = zt_feats[z_ind, b]
            # Matching contexts: the first len(z_ind) entries of matK[0].
            in_feats_c = ct[matK[0][:in_feats_z.size(0)], b]
            # Integer dtype: these indices are used to index a torch tensor
            # below, and float index arrays are rejected by torch indexing.
            noiseInd = np.zeros((in_feats_z.size(0), noise), dtype=np.int64)
            for i, z in enumerate(z_ind):
                rand = np.random.permutation(seq_len)
                orig = rand[rand != z]
                # Random negatives strictly within self.n_around steps of z ...
                rand = orig[orig < (z + self.n_around)]
                rand = rand[rand > (z - self.n_around)]
                if rand.shape[0] >= noise:
                    n_indices = rand[:noise]
                else:
                    # ... falling back to a ±(noise + 1) window when too few.
                    n_around = noise + 1
                    rand = orig[orig < (z + n_around)]
                    rand = rand[rand > (z - n_around)]
                    n_indices = rand[:noise]
                noiseInd[i] = n_indices
            # Negatives are laid out position-major: `noise` per positive, each
            # paired (via noiseC_ind) with that positive's context.
            noise_feats_z = zt_feats[
                noiseInd.reshape(in_feats_z.size(0) * noise), b]
            noise_feats_c = ct[noiseC_ind[:in_feats_z.size(0) * noise], b]
            # Scores: positives first, then negatives.
            fxt_all = self.W[k](
                torch.cat((in_feats_z, noise_feats_z), 0),
                torch.cat((in_feats_c, noise_feats_c), 0),
            )
            f_x_t_k = fxt_all[:in_feats_z.size(0)]
            # Numerically stable form of
            # -(torch.log(f_x_t_k.exp() / fxt_all.exp().sum()).sum())
            loss = -(f_x_t_k - torch.logsumexp(fxt_all, dim=0)).sum()
            slen = in_feats_z.size(0)
            if self.cpc_compute_kcer:
                # Classify each element of the sequence to compute the CER:
                # the positive sits at index 0 among its own negatives.
                nbErr = 0
                for pred in range(slen):
                    in_feats = (in_feats_c[pred], in_feats_z[pred])
                    offset = pred * noise
                    noise_f = (
                        noise_feats_c[offset:offset + noise],
                        noise_feats_z[offset:offset + noise],
                    )
                    f_x_t_n = F.bilinear(
                        torch.cat((in_feats[1].unsqueeze(0), noise_f[1]), 0),
                        torch.cat((in_feats[0].unsqueeze(0), noise_f[0]), 0),
                        self.W[k].weight,
                        self.W[k].bias,
                    )
                    probs = f_x_t_n.exp() / f_x_t_n.exp().sum()
                    nbErr += probs.argmax() != 0
            nb_k = f_x_t_k.size(0)
            if k not in lossK:
                lossK[k] = (loss, nb_k)
                if self.cpc_compute_kcer:
                    nbErrK[k] = (nbErr, nb_k)
            else:
                (tot_loss_k, nb_examples_k) = lossK[k]
                nb_examples_k += nb_k
                if self.cpc_compute_kcer:
                    (tot_errors, _) = nbErrK[k]
                    nbErrK[k] = (tot_errors + nbErr, nb_examples_k)
                # Out-of-place so the loss tensor stored earlier is not mutated.
                lossK[k] = (tot_loss_k + loss, nb_examples_k)

    tot_loss = 0
    if self.reduction == "sum":
        # Sum over k of the per-step mean losses.
        for tot_loss_k, nb_examples_k in lossK.values():
            tot_loss = tot_loss + tot_loss_k / nb_examples_k
    else:
        # Mean over every positive of every step.
        nb_examples = 0
        for tot_loss_k, nb_examples_k in lossK.values():
            tot_loss = tot_loss + tot_loss_k
            nb_examples += nb_examples_k
        tot_loss = tot_loss / nb_examples

    details = {"loss": tot_loss}
    for k, (tot_loss_k, nb_examples_k) in lossK.items():
        if self.cpc_compute_kcer:
            (nbErr, nbEx) = nbErrK[k]
            details["cer_k" + str(k + 1)] = (torch.tensor(nbErr / nbEx) * 100)
        if self.loss_details:
            details["loss_k" + str(k + 1)] = tot_loss_k / nb_examples_k
    if tot_loss.item() == float("inf") or tot_loss.item() == float("-inf"):
        print("Inf loss !!")
    return tot_loss, details
def forward(self, input1, input2, params=None):
    """Bilinear forward pass with optionally overridden parameters.

    ``params`` is a mapping holding ``'weight'`` and optionally ``'bias'``.
    When it is None, the module's own ``named_parameters()`` are used.
    """
    source = OrderedDict(self.named_parameters()) if params is None else params
    return F.bilinear(input1, input2, source['weight'], source.get('bias', None))
def forward(self, input1, input2):
    """Apply ``F.bilinear`` with ``self.weight_masked`` and ``self.bias``."""
    masked_weight = self.weight_masked
    return F.bilinear(input1, input2, masked_weight, self.bias)
def forward(self, my_params, other_params):
    """Score ``other^T W my + bias . my`` for a single pair of parameter vectors.

    Both vectors get a leading batch axis of size 1 for ``F.bilinear``. The
    linear term ``F.linear(my_params, bias[None])`` has its last axis
    squeezed, then it is added to the ``[1, 1]`` bilinear result.
    """
    left = other_params.unsqueeze(0)
    right = my_params.unsqueeze(0)
    quadratic = F.bilinear(left, right, self.weights.unsqueeze(0))
    linear = F.linear(my_params, self.bias.unsqueeze(0)).squeeze(-1)
    return quadratic + linear
def forward(self, input1, input2):
    """Bilinear forward pass with sampled parameters in training mode.

    During training, ``self._weight()`` and ``self._bias()`` supply the
    parameters (called in that order). Otherwise the means ``weight_mu`` and
    ``bias_mu`` are used.
    """
    if self.training:
        weight, bias = self._weight(), self._bias()
    else:
        weight, bias = self.weight_mu, self.bias_mu
    return F.bilinear(input1, input2, weight, bias)
def forward(self, input1, input2):
    """Bias-free bilinear map with a regularised weight.

    The weight is ``self._regulize_parameter(self.weight)``. When
    ``self.norm`` is truthy, the output is passed through ``normalize_prob``.
    """
    output = F.bilinear(input1, input2, self._regulize_parameter(self.weight), None)
    return normalize_prob(output) if self.norm else output
def forward(self, input1, input2):
    """Biaffine map ``x1^T W x2 + bias + x1 V1^T + x2 V2^T``."""
    bilinear_term = F.bilinear(input1, input2, self.W, self.bias)
    return (bilinear_term
            + F.linear(input1, self.V1, None)
            + F.linear(input2, self.V2, None))
def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
    """Quantized bilinear forward pass.

    Both inputs and the layer weight are passed through ``self.quant_handle``
    before ``F.bilinear`` is applied with ``self.bias``.

    Side effects: stores ``self.weight.clone()`` in ``self.weight_origin``,
    then replaces ``self.weight`` with its quantized version.
    """
    input1 = self.quant_handle(input1)
    input2 = self.quant_handle(input2)
    # Snapshot of the weight before quantization; presumably restored by code
    # outside this method (e.g. around the optimizer step) — not visible here.
    self.weight_origin = self.weight.clone()
    # NOTE(review): self.weight is overwritten here. Unless something restores
    # it from self.weight_origin between calls, the next forward re-quantizes
    # an already-quantized weight and weight_origin no longer holds the
    # full-precision values. Also, nn.Module rejects assigning a plain Tensor
    # to a registered Parameter name — confirm quant_handle returns a
    # Parameter here (or that weight is not a registered Parameter).
    self.weight = self.quant_handle(self.weight)
    return F.bilinear(input1, input2, self.weight, self.bias)
def forward(self, input):
    """Quadratic-feature model on ``x = input`` and ``x2 = x * x``.

    Returns ``bilinear(x, x, W) + bilinear(x, x2, M) + [x, x2] @ V + b``,
    where ``[x, x2]`` concatenates along dim 1 (so ``input`` is expected
    to be 2-D for ``torch.mm``).
    """
    squared = input * input
    quadratic = F.bilinear(input, input, self.W) + F.bilinear(input, squared, self.M)
    linear = torch.cat((input, squared), 1).mm(self.V)
    return quadratic + linear + self.b