Example #1
def bilinear_naive(input1, input2, weight, bias=None, conjugate=True):
    r"""Applies a complex bilinear transformation to the incoming complex
    data: :math:`y = x^{T/H} W z + b`.
    """

    n_out = weight.shape[0]

    ww = torch.cat([weight.real, weight.imag], dim=0)
    a, b = input1.real, input1.imag
    u, v = input2.real, input2.imag
    au, av = F.bilinear(a, u, ww, bias=None), F.bilinear(a, v, ww, bias=None)
    bu, bv = F.bilinear(b, u, ww, bias=None), F.bilinear(b, v, ww, bias=None)

    if conjugate:
        pp, qq = au + bv, av - bu
    else:
        pp, qq = au - bv, av + bu

    repp, impp = pp[..., :n_out], pp[..., n_out:]
    reqq, imqq = qq[..., :n_out], qq[..., n_out:]

    output = Cplx(repp - imqq, impp + reqq)
    if bias is not None:
        output += bias

    return output
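
Usage sketch (not part of the snippet above): bilinear_naive only needs a container exposing .real, .imag and .shape, so a minimal stand-in for cplxmodule's Cplx suffices to check it against PyTorch's native complex einsum. The stand-in class and all shapes below are assumptions chosen for illustration, and the sketch is meant to sit in the same module as the function above.

import torch
import torch.nn.functional as F


class Cplx:  # minimal stand-in for cplxmodule's Cplx container (assumption)
    def __init__(self, real, imag):
        self.real, self.imag = real, imag

    @property
    def shape(self):
        return self.real.shape


batch, n_in1, n_in2, n_out = 4, 5, 3, 7
x = Cplx(torch.randn(batch, n_in1), torch.randn(batch, n_in1))
z = Cplx(torch.randn(batch, n_in2), torch.randn(batch, n_in2))
w = Cplx(torch.randn(n_out, n_in1, n_in2), torch.randn(n_out, n_in1, n_in2))

out = bilinear_naive(x, z, w, bias=None, conjugate=True)

# reference: y_{bo} = sum_{ij} conj(x_{bi}) W_{oij} z_{bj}
ref = torch.einsum("bi,oij,bj->bo",
                   torch.complex(x.real, x.imag).conj(),
                   torch.complex(w.real, w.imag),
                   torch.complex(z.real, z.imag))
assert torch.allclose(out.real, ref.real, atol=1e-5)
assert torch.allclose(out.imag, ref.imag, atol=1e-5)

With conjugate=True the first input is conjugated, i.e. the form is x^H W z; with conjugate=False it is x^T W z.
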
Example #2
    def forward(self, input_left, input_right):
        """

        Args:
            input_left: Tensor
                the left input tensor with shape = [batch1, batch2, ..., left_features]
            input_right: Tensor
                the right input tensor with shape = [batch1, batch2, ..., right_features]

        Returns:
            output: Tensor
                the output tensor with shape = [batch1, batch2, ..., out_features]
        """

        batch_size = input_left.size()[:-1]
        batch = int(np.prod(batch_size))

        # convert left and right input to matrices [batch, left_features], [batch, right_features]
        input_left = input_left.view(batch, self.left_features)
        input_right = input_right.view(batch, self.right_features)

        # output [batch, out_features]
        output = F.bilinear(input_left, input_right, self.U, self.bias)
        output = output + F.linear(input_left,
                                   self.weight_left, None) + F.linear(
                                       input_right, self.weight_right, None)
        # convert back to [batch1, batch2, ..., out_features]
        return output.view(batch_size + (self.out_features, ))
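
The example omits the constructor it relies on; below is a minimal parameter setup under which this forward runs. The shapes follow from F.bilinear (weight of shape [out_features, in1_features, in2_features]) and F.linear (weight of shape [out_features, in_features]); the class name and the initialization scheme are assumptions, not the original code, and the forward additionally needs numpy imported as np.

import numpy as np  # used by the forward above
import torch
import torch.nn as nn


class BiaffineSketch(nn.Module):  # hypothetical name
    def __init__(self, left_features, right_features, out_features):
        super().__init__()
        self.left_features = left_features
        self.right_features = right_features
        self.out_features = out_features
        # F.bilinear expects a weight of shape [out_features, in1_features, in2_features]
        self.U = nn.Parameter(0.02 * torch.randn(out_features, left_features, right_features))
        # F.linear expects weights of shape [out_features, in_features]
        self.weight_left = nn.Parameter(0.02 * torch.randn(out_features, left_features))
        self.weight_right = nn.Parameter(0.02 * torch.randn(out_features, right_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

    # forward(self, input_left, input_right) as in the example above

Examples #3, #4, #10 and #22 implement the same biaffine form x_l U x_r + W_l x_l + W_r x_r + b, with W_l/W_r or V1/V2 in place of weight_left/weight_right.
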
Example #3
    def forward(self, input_left, input_right):
        '''

        Args:
            input_left: Tensor
                the left input tensor with shape = [batch1, batch2, ..., left_features]
            input_right: Tensor
                the right input tensor with shape = [batch1, batch2, ..., right_features]

        Returns:
            output: Tensor
                the output tensor with shape = [batch1, batch2, ..., out_features]
        '''

        # e.g. left_size = torch.Size([16, 24, 128])
        left_size = input_left.size()
        right_size = input_right.size()
        assert left_size[:-1] == right_size[:-1], \
            "batch size of left and right inputs mismatch: (%s, %s)" % (left_size[:-1], right_size[:-1])
        batch_size = int(np.prod(left_size[:-1]))
        # batch_size = 384 (= 16 * 24)
        # convert left and right input to matrices [batch_size, left_features], [batch_size, right_features]
        # input_left.shape becomes torch.Size([384, 128])
        input_left = input_left.view(batch_size, self.left_features)
        input_right = input_right.view(batch_size, self.right_features)

        # output [batch_size, out_features]
        # y = x_1*A*x_2 + b
        output = F.bilinear(input_left, input_right, self.U, self.bias)
        output = output + F.linear(input_left, self.W_l, None) + F.linear(
            input_right, self.W_r, None)
        # convert back to [batch1, batch2, ..., out_features]
        return output.view(left_size[:-1] + (self.out_features, ))
Example #4
    def forward(self, input_left, input_right):
        '''
        Args:
            input_left: Tensor
                the left input tensor with shape = [batch1, batch2, ..., left_features]
            input_right: Tensor
                the right input tensor with shape = [batch1, batch2, ..., right_features]
        Returns:
            output: Tensor
                the output tensor with shape = [batch1, batch2, ..., out_features]
        '''

        left_size = input_left.size()
        right_size = input_right.size()
        assert left_size[:-1] == right_size[:-1], \
            "batch size of left and right inputs mismatch: (%s, %s)" % (left_size[:-1], right_size[:-1])
        batch = int(np.prod(left_size[:-1]))

        # convert left and right input to matrices [batch, left_features], [batch, right_features]
        input_left = input_left.view(batch, self.left_features)
        input_right = input_right.view(batch, self.right_features)

        # output [batch, out_features]
        output = F.bilinear(input_left, input_right, self.U, self.bias)
        output = output + F.linear(input_left, self.W_l, None) + F.linear(
            input_right, self.W_r, None)
        # convert back to [batch1, batch2, ..., out_features]
        return output.view(left_size[:-1] + (self.out_features, ))
Example #5
 def forward(self, input1, input2, pgn_vector):
     weight = torch.matmul(pgn_vector,
                           self.weight.view(self.in_params, -1)).view(
                               self.out_features, self.in1_features,
                               self.in2_features)
     bias = None
     if self.bias is not None:
         bias = torch.matmul(pgn_vector, self.bias)
     return F.bilinear(input1, input2, weight, bias)
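
A shape sketch for the parameter-generation pattern above. All shapes are assumptions inferred from the matmul/view calls rather than taken from the original module: a bank of bilinear weights indexed by the pgn dimension is contracted with pgn_vector to produce one ordinary F.bilinear weight and bias.

import torch
import torch.nn.functional as F

in_params, out_features, in1_features, in2_features = 8, 4, 16, 12
weight = torch.randn(in_params, out_features, in1_features, in2_features)  # assumed layout
bias = torch.randn(in_params, out_features)                                # assumed layout
pgn_vector = torch.randn(in_params)  # e.g. a language/task embedding

# contract the weight bank with pgn_vector to get one bilinear weight and bias
gen_weight = torch.matmul(pgn_vector, weight.view(in_params, -1)).view(
    out_features, in1_features, in2_features)
gen_bias = torch.matmul(pgn_vector, bias)  # [out_features]

x1, x2 = torch.randn(2, in1_features), torch.randn(2, in2_features)
y = F.bilinear(x1, x2, gen_weight, gen_bias)  # [2, out_features]
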
Example #6
    def forward(self, input1, input2):
        mu = super().forward(input1, input2)
        if not self.training:
            return mu

        s2 = F.bilinear(input1.real * input1.real + input1.imag * input1.imag,
                        input2.real * input2.real + input2.imag * input2.imag,
                        torch.exp(self.log_sigma2), None)

        return mu + cplx.randn_like(s2) * torch.sqrt(torch.clamp(s2, 1e-8))
Example #7
    def forward(self, input1, input2):
        mu = super().forward(input1, input2)
        if not self.training:
            return mu

        s2 = F.bilinear(input1.real * input1.real + input1.imag * input1.imag,
                        input2.real * input2.real + input2.imag * input2.imag,
                        torch.exp(self.log_sigma2), None)

        noise = Cplx(*map(torch.randn_like, (s2, s2))) / sqrt(2)
        return mu + noise * torch.sqrt(torch.clamp(s2, 1e-8))
Example #8
 def forward(self, x1: Tensor, x2: Tensor):
     if self.bias_x:
         x1 = torch.cat((x1, torch.ones_like(x1[..., :1])), -1)
     if self.bias_y:
         x2 = torch.cat((x2, torch.ones_like(x2[..., :1])), -1)
     if self.expand:
         # [batch_size, n_out, seq_len, seq_len]
         s = torch.einsum('bxi,oij,byj->boxy', x1, self.weight, x2)
         return s
     # [batch_size, seq_len, n_out]
     return F.bilinear(x1, x2, self.weight, None)
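
Relationship between the two branches (an illustrative check, not part of the original module): the einsum scores every pair of positions, while the F.bilinear branch scores only aligned positions, which is the diagonal of the expanded output up to a transpose of the last two output dimensions.

import torch
import torch.nn.functional as F

batch, seq_len, dim, n_out = 2, 6, 8, 3
x1 = torch.randn(batch, seq_len, dim)
x2 = torch.randn(batch, seq_len, dim)
weight = torch.randn(n_out, dim, dim)

expanded = torch.einsum('bxi,oij,byj->boxy', x1, weight, x2)  # [batch, n_out, seq_len, seq_len]
aligned = F.bilinear(x1, x2, weight, None)                    # [batch, seq_len, n_out]

diag = torch.diagonal(expanded, dim1=-2, dim2=-1)             # [batch, n_out, seq_len]
assert torch.allclose(diag, aligned.transpose(1, 2), atol=1e-5)
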
Example #9
    def forward(self, input1, input2):
        r"""Forward pass of the SGVB method for a bilinear layer.

        Straightforward generalization of the local reparameterization trick.
        """
        mu = super().forward(input1, input2)
        if not self.training:
            return mu

        s2 = F.bilinear(input1 * input1, input2 * input2,
                        torch.exp(self.log_sigma2), None)

        # .normal reports a grad-fn, but weirdly does not pass grads!
        return mu + torch.randn_like(s2) * torch.sqrt(torch.clamp(s2, 1e-8))
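
Why the squared inputs appear: for elementwise-independent Gaussian weights W ~ N(mu, sigma^2), the output y_o = sum_{ij} x1_i W_{oij} x2_j has variance sum_{ij} x1_i^2 sigma^2_{oij} x2_j^2, which is exactly F.bilinear(x1 * x1, x2 * x2, sigma^2). A minimal Monte-Carlo sanity check under assumed shapes:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_in1, n_in2, n_out = 4, 3, 2
x1, x2 = torch.randn(1, n_in1), torch.randn(1, n_in2)
mu = torch.randn(n_out, n_in1, n_in2)
log_sigma2 = torch.randn(n_out, n_in1, n_in2) - 2.0

# analytic output variance under W ~ N(mu, exp(log_sigma2)), elementwise independent
s2 = F.bilinear(x1 * x1, x2 * x2, torch.exp(log_sigma2), None)

# Monte-Carlo estimate of the same variance by sampling weights explicitly
samples = torch.stack([
    F.bilinear(x1, x2, mu + torch.exp(0.5 * log_sigma2) * torch.randn_like(mu), None)
    for _ in range(20000)
])
assert torch.allclose(samples.var(dim=0), s2, rtol=0.1, atol=1e-4)
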
Example #10
    def forward(self, input_left, input_right):
        left_size = input_left.size()
        right_size = input_right.size()

        batch = int(np.prod(left_size[:-1]))

        input_left = input_left.view(batch, self.left_features)
        input_right = input_right.view(batch, self.right_features)

        output = F.bilinear(input_left, input_right, self.U, self.bias)
        output = output + F.linear(input_left, self.W_l, None) + F.linear(
            input_right, self.W_r, None)

        return output.view(left_size[:-1] + (self.out_features, ))
Example #11
    def forward(self, tensora, tensorb):
        """
        YVec = aHVec @ W1 @ bVVec + W2 @ aHVec + W3 @ bVVec

        Args:
            tensora: shape = [dim1, dim2, ..., tensora_dim]
            tensorb: shape = [dim1, dim2, ..., tensorb_dim]
        """
        dims_prev = tensora.size()[:-1]
        dummy_batch = int(np.prod(dims_prev))

        tensora = tensora.reshape(dummy_batch, self.tensora_dim)
        tensorb = tensorb.reshape(dummy_batch, self.tensorb_dim)

        cross = F.bilinear(tensora, tensorb, self.U, self.bias)
        partial_A = F.linear(tensora, self.tensora_weight)  # tensora @ A^T
        partial_B = F.linear(tensorb, self.tensorb_weight)  # tensorb @ B^T

        out = cross + partial_A + partial_B  # [dummy_batch, outputs_dim]
        out = out.reshape(dims_prev + (self.outputs_dim, ))

        return out
Example #12
def bilinear_cat(input1, input2, weight, bias=None, conjugate=True):
    # [n_out, n_in1, n_in2] -> [2 * n_out, 2 * n_in1, 2 * n_in2]
    U, V = weight.real, weight.imag

    UV = torch.cat([U, -V], dim=2)
    VU = torch.cat([V, U], dim=2)
    if conjugate:
        ww = torch.cat(
            [torch.cat([UV, VU], dim=1),
             torch.cat([VU, -UV], dim=1)], dim=0)
    else:
        ww = torch.cat(
            [torch.cat([UV, -VU], dim=1),
             torch.cat([VU, UV], dim=1)], dim=0)

    x1 = to_concatenated_real(input1, dim=-1)
    x2 = to_concatenated_real(input2, dim=-1)

    output = from_concatenated_real(F.bilinear(x1, x2, ww, None))
    if bias is not None:
        output += bias

    return output
Example #13
 def forward(self, in1, in2):
     us = self.left_singular * self.diag
     usv = torch.matmul(us, self.right_singular)
     return F.bilinear(in1, in2, weight=usv)
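
Shape sketch for the factorized weight above (the shapes are assumptions inferred from the broadcasting rules, not the original module): the weight passed to F.bilinear is built as a rank-limited product per output slice.

import torch
import torch.nn.functional as F

out_features, in1, in2, rank = 3, 6, 5, 2
left_singular = torch.randn(out_features, in1, rank)
diag = torch.randn(rank)
right_singular = torch.randn(rank, in2)

us = left_singular * diag               # scale the rank dimension
usv = torch.matmul(us, right_singular)  # [out_features, in1, in2], rank-limited per output slice
y = F.bilinear(torch.randn(4, in1), torch.randn(4, in2), weight=usv)  # [4, out_features]
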
Example #14
 def forward(self, input1, input2):
     dir_ = self.direction
     direction = dir_.div(dir_.pow(2).sum(1).sum(1).sqrt()[:, N_, N_])
     weight = self.scale[:, N_, N_].mul(direction)
     return F.bilinear(input1, input2, weight, self.bias)
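
Sketch of the weight-normalization style parametrization above, assuming N_ is an alias for None used to insert singleton dimensions for broadcasting (an assumption; the shapes below are illustrative): each [in1, in2] slice of direction is normalized to unit Frobenius norm and then rescaled by a learned per-output scale.

import torch
import torch.nn.functional as F

N_ = None  # assumed alias, used for broadcasting via singleton dims
out_features, in1, in2 = 3, 4, 5
direction = torch.randn(out_features, in1, in2)
scale = torch.randn(out_features)

norm = direction.pow(2).sum(1).sum(1).sqrt()  # Frobenius norm of each [in1, in2] slice
unit = direction.div(norm[:, N_, N_])         # each slice now has unit norm
weight = scale[:, N_, N_].mul(unit)           # rescale each slice by a learned magnitude

x1, x2 = torch.randn(2, in1), torch.randn(2, in2)
y = F.bilinear(x1, x2, weight)  # [2, out_features]
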
Example #15
def test_bilinear():
    A = torch.randn(3, 5, 4)
    l = torch.randn(2, 5)
    r = torch.randn(2, 4)
    assert (utils.bilinear(l, A, r) == F.bilinear(l, r, A)).all()
Example #16
    def cpc_loss(self, gru_input_feats, gru_output_feats, feats_len):
        zt_feats = gru_input_feats.contiguous().view(gru_input_feats.size(0),
                                                     gru_input_feats.size(1),
                                                     -1)
        ct = gru_output_feats.contiguous().view(gru_output_feats.size(0),
                                                gru_output_feats.size(1), -1)
        zt_length = feats_len

        tot_loss = 0
        nb_examples = 0
        lossK = {}  # key=k, value=(tot_loss_k, nb_example_k)
        nbErrK = {}  # key=k, value=(tot_err_k, nb_example_k)

        # change from BxTx(FxC) to TxBxF
        zt_feats = zt_feats.permute(1, 0, 2)
        ct = ct.permute(1, 0, 2)

        for b in range(zt_length.size(0)):
            seq_len = zt_length[b].item()

            # compute indices
            matK = np.arange(self.k + 1)[:, np.newaxis] + np.arange(0, seq_len)
            # example:
            # ct_i    (0 1 2 3 4 5 6)
            # zt_i k0 (1 2 3 4 5 6 7)
            # zt_i k1 (2 3 4 5 6 7 8)
            # ...

            noise = min(self.N, seq_len - 1)
            noiseC_ind = np.arange(seq_len * noise) // noise
            # if noise = 3, produce (0 0 0 1 1 1 2 2 2 ...)

            for k in range(self.k_step, self.k, self.k_step):
                z_ind = matK[k][matK[k] < seq_len]
                # for example if k=1, z_ind = (2 3 4 5 ... seq_len-1)
                in_feats_z = zt_feats[z_ind, b]
                # then select the zt_feats corresponding to those indices
                in_feats_c = ct[matK[0][:in_feats_z.size(0)], b]
                # and the ct_feats corresponding to the first line of matK,
                # limited to the number of values in z_feats
                # then we want to learn W as f(x_1, c_1) = exp(z_1.T W_1 c_1)

                # integer dtype: noiseInd is used below to index zt_feats
                noiseInd = np.zeros((in_feats_z.size(0), noise), dtype=np.int64)

                for i, z in enumerate(z_ind):
                    rand = np.random.permutation(seq_len)
                    orig = rand[rand != z]
                    rand = orig[orig < (z + self.n_around)]
                    rand = rand[rand > (z - self.n_around)]
                    if (rand.shape[0] >= noise):
                        n_indices = rand[:noise]
                    else:
                        n_around = noise + 1
                        rand = orig[orig < (z + n_around)]
                        rand = rand[rand > (z - n_around)]
                        n_indices = rand[:noise]
                    # take random indices (different from the index of the positive
                    # z_feat), limited to the number of noise samples we want
                    noiseInd[i] = n_indices
                    # for each value of z we have noise random indices
                    # in noiseInd matrix

                # noiseC_ind contains  (0 0 0 ...)
                # noise_Ind  contains  (rand(seq_len)!=z rand(seq_len)!=z ...)
                noise_feats_z = zt_feats[
                    noiseInd.reshape(in_feats_z.size(0) * noise), b]
                noise_feats_c = ct[noiseC_ind[:in_feats_z.size(0) * noise], b]

                fxt_all = self.W[k](
                    torch.cat((in_feats_z, noise_feats_z), 0),
                    torch.cat((in_feats_c, noise_feats_c), 0),
                )

                f_x_t_k = fxt_all[:in_feats_z.size(0)]

                # loss = -(torch.log(f_x_t_k.exp() / fxt_all.exp().sum()).sum())
                loss = -(f_x_t_k - torch.logsumexp(fxt_all, dim=0)).sum()

                slen = in_feats_z.size(0)

                if self.cpc_compute_kcer:
                    # classify each elem of the sequence to compute the cer
                    nbErr = 0
                    for pred in range(slen):
                        in_feats = (in_feats_c[pred], in_feats_z[pred])
                        offset = pred * noise
                        noise_f = (
                            noise_feats_c[offset:offset + noise],
                            noise_feats_z[offset:offset + noise],
                        )
                        f_x_t_n = F.bilinear(
                            torch.cat((in_feats[1].unsqueeze(0), noise_f[1]),
                                      0),
                            torch.cat((in_feats[0].unsqueeze(0), noise_f[0]),
                                      0),
                            self.W[k].weight,
                            self.W[k].bias,
                        )
                        probs = f_x_t_n.exp() / f_x_t_n.exp().sum()
                        nbErr += probs.argmax() != 0

                tot_loss += loss
                nb_examples += f_x_t_k.size(0)
                if k not in lossK:
                    lossK[k] = (loss, f_x_t_k.size(0))
                    if self.cpc_compute_kcer:
                        nbErrK[k] = (nbErr, f_x_t_k.size(0))
                else:
                    (tot_loss_k, nb_examples_k) = lossK[k]
                    nb_examples_k += f_x_t_k.size(0)
                    if self.cpc_compute_kcer:
                        (tot_errors, nbEx) = nbErrK[k]
                        tot_errors += nbErr
                        nbErrK[k] = (tot_errors, nb_examples_k)
                    tot_loss_k += loss
                    lossK[k] = (tot_loss_k, nb_examples_k)

        if self.reduction == "sum":
            tot_loss = 0
            for k in lossK.keys():
                (tot_loss_k, nb_examples_k) = lossK[k]
                tot_loss += tot_loss_k / nb_examples_k
        else:
            tot_loss = 0
            nb_examples = 0
            for k in lossK.keys():
                (tot_loss_k, nb_examples_k) = lossK[k]
                tot_loss += tot_loss_k
                nb_examples += nb_examples_k
            tot_loss /= nb_examples

        details = {}
        details["loss"] = tot_loss
        for k in lossK.keys():
            (tot_loss_k, nb_examples_k) = lossK[k]
            if self.cpc_compute_kcer:
                (nbErr, nbEx) = nbErrK[k]
                details["cer_k" + str(k + 1)] = (torch.tensor(nbErr / nbEx) *
                                                 100)
            if self.loss_details:
                details["loss_k" + str(k + 1)] = tot_loss_k / nb_examples_k

        if tot_loss.item() == float("inf") or tot_loss.item() == float("-inf"):
            print("Inf loss !!")

        return tot_loss, details
Example #17
 def forward(self, input1, input2, params=None):
     if params is None:
         params = OrderedDict(self.named_parameters())
     bias = params.get('bias', None)
     return F.bilinear(input1, input2, params['weight'], bias)
Example #18
 def forward(self, input1, input2):
     return F.bilinear(input1, input2, self.weight_masked, self.bias)
Example #19
 def forward(self, my_params, other_params):
     return F.bilinear(other_params[None, :], my_params[None, :], self.weights[None, :, :]) + torch.squeeze(
         F.linear(my_params, self.bias[None, :]), dim=-1
     )
Example #20
 def forward(self, input1, input2):
     weight = self._weight() if self.training else self.weight_mu
     bias = self._bias() if self.training else self.bias_mu
     return F.bilinear(input1, input2, weight, bias)
Example #21
 def forward(self, input1, input2):
     weight = self._regulize_parameter(self.weight)
     output = F.bilinear(input1, input2, weight, None)
     if self.norm:
         output = normalize_prob(output)
     return output
Example #22
 def forward(self, input1, input2):
     result = F.bilinear(input1, input2, self.W, self.bias)
     result += F.linear(input1, self.V1, None)
     result += F.linear(input2, self.V2, None)
     return result
Example #23
 def forward(self, input1: Tensor, input2: Tensor) -> Tensor:
     input1 = self.quant_handle(input1)
     input2 = self.quant_handle(input2)
     self.weight_origin = self.weight.clone()
     self.weight = self.quant_handle(self.weight)
     return F.bilinear(input1, input2, self.weight, self.bias)
Example #24
 def forward(self, input):
     input_square = torch.mul(input, input)
     return F.bilinear(input, input, self.W) + F.bilinear(input, input_square, self.M) + \
         torch.mm(torch.cat((input, input_square), 1), self.V) + self.b
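
Shape sketch for the quadratic layer above (all shapes are assumptions inferred from the operations, not from the original module): the output combines a bilinear term in the input, a bilinear term between the input and its elementwise square, a linear term on their concatenation, and a bias.

import torch
import torch.nn.functional as F

d, out_features = 5, 3
W = torch.randn(out_features, d, d)   # bilinear term on (input, input)
M = torch.randn(out_features, d, d)   # bilinear term on (input, input**2)
V = torch.randn(2 * d, out_features)  # linear term on [input, input**2]
b = torch.randn(out_features)

x = torch.randn(4, d)
x_sq = x * x
y = F.bilinear(x, x, W) + F.bilinear(x, x_sq, M) \
    + torch.mm(torch.cat((x, x_sq), 1), V) + b  # [4, out_features]
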