Example #1
 def forward(self, x):
     h = x * self.c
     h = self.conv(h)
     if self.act is not None:
         h = self.act(h)
     if self.pixelnorm:
         mean = torch.mean(h * h, 1, keepdim=True)
         dom = torch.rsqrt(mean + self.eps)
         h = h * dom
     return h
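A minimal standalone sketch of the pixel-norm step used in this forward pass; the module above also applies a convolution and activation, the helper below only reproduces the normalization (the eps value is an assumption):

import torch

def pixel_norm(h, eps=1e-8):
    # divide each pixel's feature vector by its RMS over the channel dimension
    mean = torch.mean(h * h, 1, keepdim=True)
    return h * torch.rsqrt(mean + eps)

h = torch.randn(2, 16, 8, 8)
out = pixel_norm(h)  # per-pixel channel vectors now have roughly unit RMS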
Example #2
    def rsample(self, sample_shape=torch.Size()):
        # NOTE: This does not agree with scipy implementation as much as other distributions.
        # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor
        # parameters seems to help.

        #   X ~ Normal(0, 1)
        #   Z ~ Chi2(df)
        #   Y = X / sqrt(Z / df) ~ StudentT(df)
        shape = self._extended_shape(sample_shape)
        X = self.df.new(*shape).normal_()
        Z = self._chi2.rsample(sample_shape)
        Y = X * torch.rsqrt(Z / self.df)
        return self.loc + self.scale * Y
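A hedged sketch of the same reparameterized Student-T construction outside the distribution class, using torch.distributions.Chi2 directly (the df, loc, scale and sample count are assumptions):

import torch

df, loc, scale = torch.tensor(5.0), 0.0, 1.0
shape = (10000,)
X = torch.randn(shape)
Z = torch.distributions.Chi2(df).rsample(shape)
Y = loc + scale * X * torch.rsqrt(Z / df)  # Y ~ StudentT(df), reparameterized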
Example #3
 def forward(self, x):
     h = x.unsqueeze(2).unsqueeze(3)
     if self.normalize_latents:
         mean = torch.mean(h * h, 1, keepdim=True)
         dom = torch.rsqrt(mean + self.eps)
         h = h * dom
     h = self.block0(h, self.depth == 0)
     if self.depth > 0:
         for i in range(self.depth - 1):
             h = F.upsample(h, scale_factor=2)
             h = self.blocks[i](h)
         h = F.upsample(h, scale_factor=2)
         ult = self.blocks[self.depth - 1](h, True)
         if self.alpha < 1.0:
             if self.depth > 1:
                 preult_rgb = self.blocks[self.depth - 2].toRGB(h)
             else:
                 preult_rgb = self.block0.toRGB(h)
         else:
             preult_rgb = 0
         h = preult_rgb * (1-self.alpha) + ult * self.alpha
     return h
Example #4
 def forward(self, _inp):
     x, mask = _inp
     norm = torch.rsqrt((x**2).mean(dim=1, keepdim=True) + 1e-7)
     return x * norm, mask
Example #5
 def forward(self, x):
     x = x - torch.mean(x, (2, 3), True)
     tmp = torch.mul(x, x)  # or x ** 2
     tmp = torch.rsqrt(torch.mean(tmp, (2, 3), True) + self.epsilon)
     return x * tmp
Example #6
 def forward(self, x):
     x -= torch.mean(x, dim=[2, 3], keepdim=True)
     x *= torch.rsqrt(
         torch.mean(x**2, dim=[2, 3], keepdim=True) + self.epsilon)
     return x
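This in-place variant computes the same biased instance normalization as torch.nn.functional.instance_norm; a small equivalence check, assuming matching eps:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8, 8)
eps = 1e-5
manual = x - x.mean(dim=[2, 3], keepdim=True)
manual = manual * torch.rsqrt(torch.mean(manual ** 2, dim=[2, 3], keepdim=True) + eps)
print(torch.allclose(manual, F.instance_norm(x, eps=eps), atol=1e-5))  # expected: True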
Example #7
 def test_rsqrt(self):
     x = torch.randn(3, 4, requires_grad=True)
     self.assertONNX(lambda x: torch.rsqrt(x), x)
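The test relies on the repo's assertONNX helper; a minimal export sketch without that harness (the module name and dummy input are assumptions):

import io
import torch

class Rsqrt(torch.nn.Module):
    def forward(self, x):
        return torch.rsqrt(x)

buf = io.BytesIO()
torch.onnx.export(Rsqrt(), torch.rand(3, 4) + 0.1, buf)  # positive input avoids NaN in the trace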
Example #8
 def forward(self, input):
     return input * torch.rsqrt(
         torch.mean(input**2, dim=1, keepdim=True) + 1e-8)
Example #9
 def forward(ctx, tensor, alpha=1): #pylint: disable=arguments-differ
     """Calculates the forward pass for an ISRLU unit"""
     negatives = torch.min(tensor, torch.tensor((0,), dtype=tensor.dtype))
     nisr = torch.rsqrt(1. + alpha * (negatives ** 2))
     ctx.save_for_backward(nisr)
     return tensor * nisr
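A functional sketch of the same ISRLU activation without the autograd.Function boilerplate (the default alpha mirrors the example above):

import torch

def isrlu(x, alpha=1.0):
    # identity for x >= 0; x / sqrt(1 + alpha * x^2) for x < 0
    negatives = torch.clamp(x, max=0.0)
    return x * torch.rsqrt(1.0 + alpha * negatives ** 2)

print(isrlu(torch.tensor([-2.0, -0.5, 0.0, 1.5])))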
Example #10
def normalize(x, eps=1e-10):
    return x * torch.rsqrt(torch.sum(x**2, dim=1, keepdim=True) + eps)
Example #11
    def forward(self, x):
        self._check_input_dim(x)
        if self.training:
            N, C, H, W = x.size()
            G = self.groups
            x = x.transpose(0, 1).contiguous().view(C, -1)
            mu = x.mean(1, keepdim=True)
            x = x - mu
            xxt = torch.mm(x, x.t())/(N*H*W) + torch.eye(C, out=torch.empty_like(x)) * self.eps

            assert C % G == 0
            length = int(C / G)
            xxti = torch.chunk(xxt, G, dim=0)
            xxtj = [torch.chunk(xxti[j], G, dim=1)[j] for j in range(G)]

            xg = list(torch.chunk(x, G, dim=0))

            xgr_list = []
            for i in range(G):
                u, e, v = torch.svd(xxtj[i])
                ratio = torch.cumsum(e, 0) / e.sum()
                counter_i = 1
                for j in range(length):
                    if e[j] <= self.eps:  # ratio[j] >= (1 - self.eps) or e[j] <= self.eps:
                        # print('{}/{} eigen-vectors selected'.format(j + 1, length))
                        counter_i = j + 1  # at least keep 99.99% energy
                        break
                subspace = torch.zeros_like(xxtj[i])
                for j in range(counter_i):
                    lambda_ij = e[j]
                    if lambda_ij < 0:
                        print('eigenvalues: ', e)
                        print("Error message: negative SVD lambda_ij {} vs SVD lambda_ij {}..".format(lambda_ij, e[j]))
                        break
                    eigenvector_ij = v[:, j][..., None]
                    subspace += torch.mm(eigenvector_ij, torch.rsqrt(lambda_ij) * eigenvector_ij.t())
                xgr = torch.mm(subspace, xg[i])
                xgr_list.append(xgr)

                with torch.no_grad():
                    running_subspace = self.__getattr__('running_subspace' + str(i))
                    running_subspace.data = (1 - self.momentum) * running_subspace.data + self.momentum * subspace.data
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu

            xr = torch.cat(xgr_list, dim=0)
            xr = xr * self.weight + self.bias
            xr = xr.view(C, N, H, W).transpose(0, 1)
            return xr

        else:
            N, C, H, W = x.size()
            x = x.transpose(0, 1).contiguous().view(C, -1)
            x = (x - self.running_mean)
            G = self.groups
            xg = list(torch.chunk(x, G, dim=0))
            for i in range(G):
                subspace = self.__getattr__('running_subspace' + str(i))
                xg[i] = torch.mm(subspace, xg[i])
            x = torch.cat(xg, dim=0)
            x = x * self.weight + self.bias
            x = x.view(C, N, H, W).transpose(0, 1)
            return x
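The per-group subspace built above is an eigendecomposition-based inverse square root of the covariance; a compact sketch of that step on a single group (the sizes and eps are assumptions):

import torch

C, eps = 8, 1e-5
x = torch.randn(C, 256)
x = x - x.mean(1, keepdim=True)
cov = x.mm(x.t()) / x.size(1) + eps * torch.eye(C)
evals, evecs = torch.linalg.eigh(cov)
whiten = evecs.mm(torch.diag(torch.rsqrt(evals))).mm(evecs.t())  # cov^(-1/2)
xw = whiten.mm(x)  # whitened: xw.mm(xw.t()) / 256 is close to the identity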
Example #12
def cosineSim(wv1, wv2):
    # cosine similarity: dot product scaled by the inverse norms (via rsqrt)
    dot = torch.sum(torch.mul(wv1, wv2))
    wv1abs = torch.rsqrt(torch.sum(torch.mul(wv1, wv1)))
    wv2abs = torch.rsqrt(torch.sum(torch.mul(wv2, wv2)))
    return dot * wv1abs * wv2abs
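A quick check of the helper against torch.nn.functional.cosine_similarity (random vectors assumed):

import torch
import torch.nn.functional as F

wv1, wv2 = torch.randn(16), torch.randn(16)
print(torch.allclose(cosineSim(wv1, wv2), F.cosine_similarity(wv1, wv2, dim=0), atol=1e-6))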
Example #13
    def forward(self, x):
        self._check_input_dim(x)
        if self.training:
            N, C, H, W = x.size()
            G = self.groups
            x = x.transpose(0, 1).contiguous().view(C, -1)
            mu = x.mean(1, keepdim=True)
            x = x - mu
            xxt = torch.mm(x, x.t()) / (N * H * W) + torch.eye(C, out=torch.empty_like(x)) * self.eps

            assert C % G == 0
            length = int(C/G)
            xxti = torch.chunk(xxt, G, dim=0)
            xxtj = [torch.chunk(xxti[j], G, dim=1)[j] for j in range(G)]

            xg = list(torch.chunk(x, G, dim=0))

            xgr_list = []
            for i in range(G):
                subspace = torch.zeros_like(xxtj[i])
                for j in range(length):
                    # initialize eigenvector with random values
                    eigenvector_ij = self.__getattr__('eigenvector{}-{}'.format(i, j))
                    v = l2normalize(torch.randn_like(eigenvector_ij))
                    eigenvector_ij.data = v.data

                    eigenvector_ij = self.power_layer(xxtj[i], eigenvector_ij)
                    lambda_current = torch.mm(xxtj[i].mm(eigenvector_ij).t(), eigenvector_ij)/torch.mm(eigenvector_ij.t(), eigenvector_ij)
                    if j == 0:
                        lambda_ij = lambda_current
                    elif lambda_ij < lambda_current or lambda_current < self.eps:
                        break
                    else:
                        lambda_ij = lambda_current
                    subspace += torch.mm(eigenvector_ij, torch.rsqrt(lambda_ij).mm(eigenvector_ij.t()))
                    # remove projections on the eigenvectors
                    xxtj[i] = xxtj[i] - torch.mm(xxtj[i], eigenvector_ij.mm(eigenvector_ij.t()))

                xgr = torch.mm(subspace, xg[i])
                xgr_list.append(xgr)

                with torch.no_grad():
                    running_subspace = self.__getattr__('running_subspace' + str(i))
                    running_subspace.data = (1 - self.momentum) * running_subspace.data + self.momentum * subspace.data

            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu

            xr = torch.cat(xgr_list, dim=0)
            xr = xr * self.weight + self.bias
            xr = xr.view(C, N, H, W).transpose(0, 1)

            return xr

        else:
            N, C, H, W = x.size()
            x = x.transpose(0, 1).contiguous().view(C, -1)
            x = (x - self.running_mean)
            G = self.groups
            xg = list(torch.chunk(x, G, dim=0))
            for i in range(G):
                subspace = self.__getattr__('running_subspace' + str(i))
                xg[i] = torch.mm(subspace, xg[i])
            x = torch.cat(xg, dim=0)
            x = x * self.weight + self.bias
            x = x.view(C, N, H, W).transpose(0, 1)
            return x
Example #14
    def forward(self, x):
        self._check_input_dim(x)
        if self.training:
            N, C, H, W = x.size()
            G = self.groups
            x = x.transpose(0, 1).contiguous().view(C, -1)
            mu = x.mean(1, keepdim=True)
            x = x - mu
            xxt = torch.mm(x, x.t()) / (N * H * W) + torch.eye(C, out=torch.empty_like(x)) * self.eps

            assert C % G == 0
            length = int(C/G)
            xxti = torch.chunk(xxt, G, dim=0)
            xxtj = [torch.chunk(xxti[j], G, dim=1)[j] for j in range(G)]

            xg = list(torch.chunk(x, G, dim=0))

            xgr_list = []
            for i in range(G):
                counter_i = 0
                # compute eigenvectors of subgroups no grad
                with torch.no_grad():
                    u, e, v = torch.svd(xxtj[i])
                    ratio = torch.cumsum(e, 0)/e.sum()
                    for j in range(length):
                        if ratio[j] >= (1 - self.eps) or e[j] <= self.eps:
                            print('{}/{} eigen-vectors selected'.format(j + 1, length))
                            print(e[0:counter_i])
                            break
                        eigenvector_ij = self.__getattr__('eigenvector{}-{}'.format(i, j))
                        eigenvector_ij.data = v[:, j][..., None].data
                        counter_i = j + 1

                # feed eigenvectors to Power Iteration Layer with grad and compute whitened tensor
                subspace = torch.zeros_like(xxtj[i])
                for j in range(counter_i):
                    eigenvector_ij = self.__getattr__('eigenvector{}-{}'.format(i, j))
                    eigenvector_ij = self.power_layer(xxtj[i], eigenvector_ij)
                    lambda_ij = torch.mm(xxtj[i].mm(eigenvector_ij).t(), eigenvector_ij)/torch.mm(eigenvector_ij.t(), eigenvector_ij)
                    if lambda_ij < 0:
                        print('eigenvalues: ', e)
                        print("Warning message: negative PI lambda_ij {} vs SVD lambda_ij {}..".format(lambda_ij, e[j]))
                        break
                    diff_ratio = (lambda_ij - e[j]).abs()/e[j]
                    if diff_ratio > 0.1:
                        break
                    subspace += torch.mm(eigenvector_ij, torch.rsqrt(lambda_ij).mm(eigenvector_ij.t()))
                    xxtj[i] = xxtj[i] - torch.mm(xxtj[i], eigenvector_ij.mm(eigenvector_ij.t()))
                xgr = torch.mm(subspace, xg[i])
                xgr_list.append(xgr)

                with torch.no_grad():
                    running_subspace = self.__getattr__('running_subspace' + str(i))
                    running_subspace.data = (1 - self.momentum) * running_subspace.data + self.momentum * subspace.data

            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mu

            xr = torch.cat(xgr_list, dim=0)
            xr = xr * self.weight + self.bias
            xr = xr.view(C, N, H, W).transpose(0, 1)

            return xr

        else:
            N, C, H, W = x.size()
            x = x.transpose(0, 1).contiguous().view(C, -1)
            x = (x - self.running_mean)
            G = self.groups
            xg = list(torch.chunk(x, G, dim=0))
            for i in range(G):
                subspace = self.__getattr__('running_subspace' + str(i))
                xg[i] = torch.mm(subspace, xg[i])
            x = torch.cat(xg, dim=0)
            x = x * self.weight + self.bias
            x = x.view(C, N, H, W).transpose(0, 1)
            return x
Example #15
    def forward(self, input, style, labels=None):
        batch, in_channel, height, width = input.shape

        if not self.fused:
            weight = self.scale * self.weight.squeeze(0)
            if self.conditional_bias:
                style = self.modulation(style, labels)
            else:
                style = self.modulation(style)

            if self.demodulate:
                w = weight.unsqueeze(0) * style.view(batch, 1, in_channel, 1,
                                                     1)
                dcoefs = (w.square().sum((2, 3, 4)) + 1e-8).rsqrt()

            input = input * style.reshape(batch, in_channel, 1, 1)

            if self.upsample:
                weight = weight.transpose(0, 1)
                out = conv2d_gradfix.conv_transpose2d(input,
                                                      weight,
                                                      padding=0,
                                                      stride=2)
                out = self.blur(out)

            elif self.downsample:
                input = self.blur(input)
                out = conv2d_gradfix.conv2d(input, weight, padding=0, stride=2)

            else:
                out = conv2d_gradfix.conv2d(input,
                                            weight,
                                            padding=self.padding)

            if self.demodulate:
                out = out * dcoefs.view(batch, -1, 1, 1)

            return out

        if self.conditional_bias:
            style = self.modulation(style,
                                    labels).view(batch, 1, in_channel, 1, 1)
        else:
            style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
        weight = self.scale * self.weight * style

        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
            weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

        weight = weight.view(batch * self.out_channel, in_channel,
                             self.kernel_size, self.kernel_size)

        if self.upsample:
            input = input.view(1, batch * in_channel, height, width)
            weight = weight.view(batch, self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size)
            weight = weight.transpose(1, 2).reshape(batch * in_channel,
                                                    self.out_channel,
                                                    self.kernel_size,
                                                    self.kernel_size)
            out = conv2d_gradfix.conv_transpose2d(input,
                                                  weight,
                                                  padding=0,
                                                  stride=2,
                                                  groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)
            out = self.blur(out)

        elif self.downsample:
            input = self.blur(input)
            _, _, height, width = input.shape
            input = input.view(1, batch * in_channel, height, width)
            out = conv2d_gradfix.conv2d(input,
                                        weight,
                                        padding=0,
                                        stride=2,
                                        groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        else:
            input = input.view(1, batch * in_channel, height, width)
            out = conv2d_gradfix.conv2d(input,
                                        weight,
                                        padding=self.padding,
                                        groups=batch)
            _, _, height, width = out.shape
            out = out.view(batch, self.out_channel, height, width)

        return out
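A shape-only sketch of the modulate/demodulate step used in both branches above; torch.rsqrt over (in_ch, k, k) rescales each per-sample output filter to unit norm (all sizes are assumptions):

import torch

batch, out_ch, in_ch, k = 2, 4, 3, 3
weight = torch.randn(1, out_ch, in_ch, k, k)
style = torch.randn(batch, 1, in_ch, 1, 1)
w = weight * style                                    # modulate: (B, out_ch, in_ch, k, k)
demod = torch.rsqrt(w.pow(2).sum([2, 3, 4]) + 1e-8)   # (B, out_ch)
w = w * demod.view(batch, out_ch, 1, 1, 1)            # demodulate: unit-norm filters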
Example #16
#     # variance = torch.var(prob_i, 0)
#     # print(variance)
#     # variance = torch.rsqrt(variance)
#     variance = torch.log(variance)
#     print(variance)

# a1=torch.tensor([5.0])
# print(a1-torch.mean(a1))
# print(torch.norm(a1-torch.mean(a1),p=2))
# print(torch.log(1/torch.norm(a1-torch.mean(a1),p=2)))
# print(torch.var(a1,0))
# print(a[0].index_select(0,torch.tensor([0,2,4])))

# print(torch.index_select(a, mask))
# print(a.index_select(0,mask))

# variance (`a` is assumed to be a 2-D tensor defined earlier in the source file)
var = torch.var(a, 1)

var2 = torch.var_mean(a, 1)
# reciprocal of the square root
var3 = torch.rsqrt(var)
b = torch.nn.functional.softmax(a, 1)

# b = tensor([[4, 3, 2, 1, ], [1, 2, 3, 4]])
# a = a.cuda()
# b = b.cuda()
# print(a)
# a_soft = F.softmax(a, 1)
# print(a_soft)
Example #17
 def forward(self, x):
     return x + torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
Example #18
    def step(self, steps=1):
        """
        Take a projection step.
        Arguments:
            steps (int): Number of steps to take. If this
                exceeds the remaining steps of the projection
                that amount of steps is taken instead. Default
                value is 1.
        """
        self._check_job()

        remaining_steps = self._job.num_steps - self._job.current_step
        if not remaining_steps > 0:
            warnings.warn(
                'Trying to take a projection step after the ' + \
                'final projection iteration has been completed.'
            )
        if steps < 0:
            steps = remaining_steps
        steps = min(remaining_steps, steps)

        if not steps > 0:
            return

        for _ in range(steps):

            if self._job.current_step >= self._job.num_steps:
                break

            # Hyperparameters.
            t = self._job.current_step / self._job.num_steps
            noise_strength = self._dlatent_std * self._job.initial_noise_factor \
                             * max(0.0, 1.0 - t / self._job.noise_ramp_length) ** 2
            lr_ramp = min(1.0, (1.0 - t) / self._job.lr_rampdown_length)
            lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
            lr_ramp = lr_ramp * min(1.0, t / self._job.lr_rampup_length)
            learning_rate = self._job.initial_learning_rate * lr_ramp

            for param_group in self._job.opt.param_groups:
                param_group['lr'] = learning_rate

            dlatents = self._job.dlatent_param + noise_strength * self._job.noise_tensor.normal_(
            )

            output = self.G_synthesis(dlatents)
            assert output.size() == self._job.target.size(), \
                'target size {} does not fit output size {} of generator'.format(
                    self._job.target.size(), output.size())

            output_scaled = self._scale_for_lpips(output)

            # Main loss: LPIPS distance of output and target
            lpips_distance = torch.mean(
                self.lpips_model(output_scaled, self._job.target_scaled))

            # Calculate noise regularization loss
            reg_loss = 0
            for p in self._job.noise_params:
                size = min(p.size()[2:])
                dim = p.dim() - 2
                while True:
                    reg_loss += torch.mean(
                        (p * p.roll(shifts=[1] * dim,
                                    dims=list(range(2, 2 + dim))))**2)
                    if size <= 8:
                        break
                    p = F.interpolate(p, scale_factor=0.5, mode='area')
                    size = size // 2

            # Combine loss, backward and update params
            loss = lpips_distance + self._job.regularize_noise_weight * reg_loss
            self._job.opt.zero_grad()
            loss.backward()
            self._job.opt.step()

            # Normalize noise values
            for p in self._job.noise_params:
                with torch.no_grad():
                    p_mean = p.mean(dim=list(range(1, p.dim())), keepdim=True)
                    p_rstd = torch.rsqrt(
                        torch.mean((p - p_mean)**2,
                                   dim=list(range(1, p.dim())),
                                   keepdim=True) + 1e-8)
                    p.data = (p.data - p_mean) * p_rstd

            self._job.current_step += 1

            if self._job.verbose:
                self._job.value_tracker.add('loss', float(loss))
                self._job.value_tracker.add('lpips_distance',
                                            float(lpips_distance))
                self._job.value_tracker.add('noise_reg', float(reg_loss))
                self._job.value_tracker.add('lr', learning_rate, beta=0)
                self._job.progress.write(self._job.verbose_prefix,
                                         str(self._job.value_tracker))
                if self._job.current_step >= self._job.num_steps:
                    self._job.progress.close()
Example #19
 def forward(self, x):
     mean = torch.mean(x * x, 1, keepdim=True)
     dom = torch.rsqrt(mean + self.eps)
     x = x * dom
     return x
Example #20
# a placeholder input for the unary ops below (assumed here; positive so rsqrt is defined)
a = torch.rand(4) + 0.5

# real
torch.randn(4, dtype=torch.cfloat).real

# reciprocal
torch.reciprocal(a)

# remainder
torch.remainder(torch.tensor([-3., -2, -1, 1, 2, 3]), 2)
torch.remainder(torch.tensor([1, 2, 3, 4, 5]), 1.5)

# round
torch.round(a)

# rsqrt
torch.rsqrt(a)

# sigmoid
torch.sigmoid(a)

# sign
torch.sign(torch.tensor([0.7, -1.2, 0., 2.3]))

# sgn
torch.tensor([3 + 4j, 7 - 24j, 0, 1 + 2j]).sgn()

# signbit
torch.signbit(torch.tensor([0.7, -1.2, 0., 2.3]))

# sin
torch.sin(a)
Example #21
 def forward(self, x):
     mean = x.mean(-1, keepdim=True)
     s = (x - mean).pow(2).mean(-1, keepdim=True)
     x = (x - mean) * torch.rsqrt(s + self.eps)
     return self.scale * x + self.shift
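This is layer normalization over the last dimension; a small equivalence check against torch.nn.functional.layer_norm, with eps assumed:

import torch
import torch.nn.functional as F

x, eps = torch.randn(4, 10), 1e-5
mean = x.mean(-1, keepdim=True)
s = (x - mean).pow(2).mean(-1, keepdim=True)
manual = (x - mean) * torch.rsqrt(s + eps)
print(torch.allclose(manual, F.layer_norm(x, (10,), eps=eps), atol=1e-6))  # expected: True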
Example #22
grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()
out_sbn_c_last = sbn_c_last(inp_sbn_c_last)
out_sbn_c_last.backward(grad_sbn_c_last)

sbn_result = True
sbn_result_c_last = True
bn_result = True

sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
#sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
sbn_result = compare("comparing biased variance: ", var_biased, b_v,
                     error) and sbn_result

out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
out_r = weight_r * (
    inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1, 1, 1) + eps) + bias_r

sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
compare("comparing bn output: ", out_bn, out_r, error)

grad_output_t = type_tensor(grad)

grad_output_r = ref_tensor(
    grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
grad_output2_r = ref_tensor(grad)

grad_bias_r = grad_output_r.sum(1)
grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) *
                 torch.rsqrt(b_v.view(-1, 1, 1) + eps) *
                 grad_output2_r).transpose(1, 0).contiguous().view(
                     feature_size, -1).sum(1)
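The reference formula used for these comparisons is standard batch normalization with biased variance; a self-contained sketch against F.batch_norm (shapes and eps are assumptions):

import torch
import torch.nn.functional as F

x, eps = torch.randn(4, 3, 8, 8), 1e-5
m = x.mean(dim=(0, 2, 3))
b_v = x.var(dim=(0, 2, 3), unbiased=False)
out_r = (x - m.view(1, -1, 1, 1)) * torch.rsqrt(b_v.view(1, -1, 1, 1) + eps)
print(torch.allclose(out_r, F.batch_norm(x, None, None, training=True, eps=eps), atol=1e-5))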
Example #23
def normalize_adj(mx):
    rowsum = mx.sum(1, keepdim=False)
    r_inv_sqrt = torch.rsqrt(rowsum)
    r_inv_sqrt[torch.isinf(r_inv_sqrt)] = 0.
    r_mat_inv_sqrt = torch.diag(r_inv_sqrt)
    return mx.mm(r_mat_inv_sqrt).transpose(0, 1).mm(r_mat_inv_sqrt)
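A minimal usage sketch of normalize_adj on a small symmetric adjacency matrix (self-loops assumed already added), giving the usual D^{-1/2} A D^{-1/2} normalization:

import torch

A = torch.tensor([[1., 1., 0.],
                  [1., 1., 1.],
                  [0., 1., 1.]])
print(normalize_adj(A))  # rows/columns rescaled by the inverse square root of the degrees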
Example #24
def test(parser, visualisation=None):

    parser = updateParser(parser)

    kwargs = vars(parser.parse_args())

    # Parameters
    name = getVal(kwargs, "name", None)
    if name is None:
        raise ValueError("You need to input a name")

    module = getVal(kwargs, "module", None)
    if module is None:
        raise ValueError("You need to input a module")

    imgPath = getVal(kwargs, "inputImage", None)
    if imgPath is None:
        raise ValueError("You need to input an image path")

    scale = getVal(kwargs, "scale", None)
    iter = getVal(kwargs, "iter", None)
    nRuns = getVal(kwargs, "nRuns", 1)

    checkPointDir = os.path.join(kwargs["dir"], name)
    checkpointData = getLastCheckPoint(checkPointDir,
                                       name,
                                       scale=scale,
                                       iter=iter)
    weights = getVal(kwargs, 'weights', None)

    if checkpointData is None:
        raise FileNotFoundError(
            "No checkpoint found for model " + str(name) + " at directory "
            + str(checkPointDir) + 'cwd=' + str(os.getcwd()))

    modelConfig, pathModel, _ = checkpointData

    keysLabels = None
    with open(modelConfig, 'rb') as file:
        keysLabels = json.load(file)["attribKeysOrder"]
    if keysLabels is None:
        keysLabels = {}

    packageStr, modelTypeStr = getNameAndPackage(module)
    modelType = loadmodule(packageStr, modelTypeStr)

    visualizer = GANVisualizer(
        pathModel, modelConfig, modelType, visualisation)

    # Load the image
    targetSize = visualizer.model.getSize()

    baseTransform = standardTransform(targetSize)

    img = pil_loader(imgPath)
    input = baseTransform(img)
    input = input.view(1, input.size(0), input.size(1), input.size(2))

    pathsModel = getVal(kwargs, "featureExtractor", None)
    featureExtractors = []
    imgTransforms = []

    if weights is not None:
        if pathsModel is None or len(pathsModel) != len(weights):
            raise AttributeError(
                "The number of weights must match the number of models")

    if pathsModel is not None:
        for path in pathsModel:
            if path == "id":
                featureExtractor = IDModule()
                imgTransform = IDModule()
            else:
                featureExtractor, mean, std = buildFeatureExtractor(
                    path, resetGrad=True)
                imgTransform = FeatureTransform(mean, std, size=128)  # None)
            featureExtractors.append(featureExtractor)
            imgTransforms.append(imgTransform)
    else:
        featureExtractors = IDModule()
        imgTransforms = IDModule()

    basePath = os.path.splitext(imgPath)[0] + "_" + kwargs['suffix']

    if not os.path.isdir(basePath):
        os.mkdir(basePath)

    basePath = os.path.join(basePath, os.path.basename(basePath))

    print("All results will be saved in " + basePath)

    outDictData = {}
    outPathDescent = None

    fullInputs = torch.cat([input for x in range(nRuns)], dim=0)

    if kwargs['save_descent']:
        outPathDescent = os.path.join(
            os.path.dirname(basePath), "descent")
        if not os.path.isdir(outPathDescent):
            os.mkdir(outPathDescent)

    img, outVectors, loss = gradientDescentOnInput(visualizer.model,
                                                   fullInputs,
                                                   featureExtractors,
                                                   imgTransforms,
                                                   visualizer=visualisation,
                                                   lambdaD=kwargs['lambdaD'],
                                                   nSteps=kwargs['nSteps'],
                                                   weights=weights,
                                                   randomSearch=kwargs['random_search'],
                                                   nevergrad=kwargs['nevergrad'],
                                                   lr=kwargs['learningRate'],
                                                   outPathSave=outPathDescent)

    pathVectors = basePath + "vector.pt"
    torch.save(outVectors, open(pathVectors, 'wb'))

    path = basePath + ".jpg"
    visualisation.saveTensor(img, (img.size(2), img.size(3)), path)
    outDictData[os.path.splitext(os.path.basename(path))[0]] = \
        [x.item() for x in loss]

    outVectors = outVectors.view(outVectors.size(0), -1)
    outVectors *= torch.rsqrt((outVectors**2).mean(dim=1, keepdim=True))

    barycenter = outVectors.mean(dim=0)
    barycenter *= torch.rsqrt((barycenter**2).mean())
    meanAngles = (outVectors * barycenter).mean(dim=1)
    meanDist = torch.sqrt(((barycenter-outVectors)**2).mean(dim=1)).mean(dim=0)
    outDictData["Barycenter"] = {"meanDist": meanDist.item(),
                                 "stdAngles": meanAngles.std().item(),
                                 "meanAngles": meanAngles.mean().item()}

    path = basePath + "_data.json"
    outDictData["kwargs"] = kwargs

    with open(path, 'w') as file:
        json.dump(outDictData, file, indent=2)

    pathVectors = basePath + "vectors.pt"
    torch.save(outVectors, open(pathVectors, 'wb'))
Example #25
def zca_mean(
    inp: torch.Tensor,
    dim: int = 0,
    unbiased: bool = True,
    eps: float = 1e-6,
    return_inverse: bool = False
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    r"""

    Computes the ZCA whitening matrix and mean vector. The output can be used with
    :py:meth:`~kornia.color.linear_transform`

    See :class:`~kornia.color.ZCAWhitening` for details.


    args:
        inp (torch.Tensor) : input data tensor
        dim (int): Specifies the dimension that serves as the samples dimension. Default = 0
        unbiased (bool): Whether to use the unbiased estimate of the covariance matrix. Default = True
        eps (float) : a small number used for numerical stability. Default = 1e-6
        return_inverse (bool): Whether to return the inverse ZCA transform.

    shapes:
        - inp: :math:`(D_0,...,D_{\text{dim}},...,D_N)` is a batch of N-D tensors.
        - transform_matrix: :math:`(\Pi_{d=0,d\neq \text{dim}}^N D_d, \Pi_{d=0,d\neq \text{dim}}^N D_d)`
        - mean_vector: :math:`(1, \Pi_{d=0,d\neq \text{dim}}^N D_d)`
        - inv_transform: same shape as the transform matrix

    returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        A tuple containing the ZCA matrix and the mean vector. If return_inverse is set to True,
        then it returns the inverse ZCA matrix, otherwise it returns None.

    Examples:
        >>> x = torch.tensor([[0,1],[1,0],[-1,0],[0,-1]], dtype = torch.float32)
        >>> transform_matrix, mean_vector,_ = zca_mean(x) # Returns transformation matrix and data mean
        >>> x = torch.rand(3,20,2,2)
        >>> transform_matrix, mean_vector, inv_transform = zca_mean(x, dim = 1, return_inverse = True)
        >>> # transform_matrix.size() equals (12, 12) and mean_vector.size() equals (1, 12)

    """

    if not isinstance(inp, torch.Tensor):
        raise TypeError("Input type is not a torch.Tensor. Got {}".format(
            type(inp)))

    if not isinstance(eps, float):
        raise TypeError(f"eps type is not a float. Got{type(eps)}")

    if not isinstance(unbiased, bool):
        raise TypeError(f"unbiased type is not bool. Got{type(unbiased)}")

    if not isinstance(dim, int):
        raise TypeError("Argument 'dim' must be of type int. Got {}".format(
            type(dim)))

    if not isinstance(return_inverse, bool):
        raise TypeError(
            "Argument return_inverse must be of type bool {}".format(
                type(return_inverse)))

    inp_size = inp.size()

    if dim >= len(inp_size) or dim < -len(inp_size):
        raise IndexError(
            "Dimension out of range (expected to be in range of [{},{}], but got {}"
            .format(-len(inp_size),
                    len(inp_size) - 1, dim))

    if dim < 0:
        dim = len(inp_size) + dim

    feat_dims = torch.cat(
        [torch.arange(0, dim),
         torch.arange(dim + 1, len(inp_size))])

    new_order: List[int] = torch.cat([torch.tensor([dim]), feat_dims]).tolist()

    inp_permute = inp.permute(new_order)

    N = inp_size[dim]
    feature_sizes = torch.tensor(inp_size[0:dim] + inp_size[dim + 1::])
    num_features: int = int(torch.prod(feature_sizes).item())

    mean: torch.Tensor = torch.mean(inp_permute, dim=0, keepdim=True)

    mean = mean.reshape((1, num_features))

    inp_center_flat: torch.Tensor = inp_permute.reshape(
        (N, num_features)) - mean

    cov = inp_center_flat.t().mm(inp_center_flat)

    if unbiased:
        cov = cov / float(N - 1)
    else:
        cov = cov / float(N)

    U, S, _ = _torch_svd_cast(cov)

    S = S.reshape(-1, 1)
    S_inv_root: torch.Tensor = torch.rsqrt(S + eps)
    T: torch.Tensor = (U).mm(S_inv_root * U.t())

    T_inv: Optional[torch.Tensor] = None
    if return_inverse:
        T_inv = (U).mm(torch.sqrt(S + eps) * U.t())

    return T, mean, T_inv
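A hedged sketch of applying the returned transform to whiten the data it was computed from (zca_mean and its _torch_svd_cast dependency are assumed to be in scope; T is symmetric here):

import torch

x = torch.tensor([[0., 1.], [1., 0.], [-1., 0.], [0., -1.]])
T, mean, _ = zca_mean(x)
x_zca = (x - mean).mm(T)          # whitened samples
print(x_zca.t().mm(x_zca) / 3.0)  # approximately the identity matrix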
Example #26
 def forward(self, x):
     return x * torch.rsqrt(
         torch.mean(x**2, dim=1, keepdim=True) + self.epsilon)
Example #27
def instance_norm(x, eps=1e-8):
    """Instance normalization. """
    assert len(x.shape) == 4, "shape of input should be NCHW!"
    x -= torch.mean(x, dim=(2, 3), keepdim=True)
    return x * torch.rsqrt(torch.mean(x ** 2, dim=(2, 3), keepdim=True) + eps)
Example #28
    def forward(self, x):
        tmp = torch.mul(x, x)  # or x ** 2
        tmp1 = torch.rsqrt(torch.mean(tmp, dim=1, keepdim=True) + self.epsilon)

        return x * tmp1
Example #29
    def forward(self, x, abs_x, deg, idx):
        batch, channels, npoints, neighbors = x.size()  # B, C, N, K
        ''' 1. get point features (B, C, N) '''
        x_q = abs_x  # B, C//2, N, 1
        x_kv = x  # B, C, N, K
        ''' 2. transform by Wq, Wk, Wv '''
        q_out = self.query_conv(x_q)  # B, C, N, 1
        k_out = self.key_conv(x_kv)  # B, C, N, K
        k_out_all = k_out[:, :, :, 0]  # B, C, N, 1
        v_out = self.value_conv(x_kv)  # B, C, N, K
        v_out_all = v_out[:, :, :, 0]  # B, C, N, 1
        ''' 3. relative positional encoding '''
        if self.rpe:
            k_out = k_out + self.rel_k

        # k_out : B, C, N, K / self.rel_k : C, 1, K
        ''' 4. multi-head attention '''
        if self.scale:
            scaler = torch.tensor([self.out_channels / self.groups]).cuda()
            out = torch.rsqrt(scaler) * q_out * k_out
        else:
            out = q_out * k_out  # B, C, N, K

        out = F.softmax(out, dim=-1)  # B, C, N, K
        ''' 5. scoring '''
        idx = idx[:, :, :, -1].unsqueeze(1).expand_as(
            out).cuda()  # B, N, K -> B, 1, N, K -> B, C, N, K
        idx_scatter = torch.zeros(batch,
                                  self.out_channels,
                                  npoints,
                                  npoints,
                                  device='cuda').detach()  # B, C, N, N

        # node-wise importance
        idx_scatter.scatter_(dim=3, index=idx,
                             src=out)[0, 0, 0, :]  # B,C,N,N -> B,C,N,1
        score = idx_scatter.sum(dim=2, keepdim=True).transpose(2, 3).squeeze(
            3)  # B, C, 1, N -> B, C, N, 1 -> B, C, N

        idx_key, idx_salient = score.topk(
            k=20, dim=-1)  # B, C, S | here, S : sampled global points

        k_out_all = torch.gather(k_out_all, 2,
                                 idx_salient).unsqueeze(2).repeat(
                                     1, 1, npoints, 1)  # B, C, S -> B, C, N, S
        v_out_all = torch.gather(v_out_all, 2,
                                 idx_salient).unsqueeze(2).repeat(
                                     1, 1, npoints, 1)  # B, C, S -> B, C, N, S

        out_all = q_out * k_out_all  # B, C, N, S
        out_all = F.softmax(out_all, dim=-1)  # B, C, N, S
        out_all = torch.einsum('bcns,bcns -> bcn', out_all,
                               v_out_all)  # B, C, N
        out_all = out_all.view(batch, -1, npoints, 1)  # B, C, N, 1

        # x : 6 x 3  /  idx : 6
        # x : B,C,N x K / idx : B,C,N,K

        out = torch.einsum('bcnk,bcnk -> bcn', out, v_out)  # b, C, N, K
        out = out.view(batch, -1, npoints, 1)  # b, C, N, 1
        ''' 8. reshape for memory '''
        #out = torch.cat([out, out_all], dim = 1)
        if self.layer > 0:
            #print("ratio : {}".format(out_all.mean()/out.mean()))
            out = torch.cat([out, out_all], dim=1)

        if self.return_kv:
            k_out = k_out.view(batch, -1, npoints, neighbors, 1)
            v_out = v_out.view(batch, -1, npoints, neighbors, 1)
            return out, k_out, v_out

        else:
            return out
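The scaling in step 4 is the usual 1/sqrt(d) attention temperature written with torch.rsqrt; a small sketch with an assumed per-head dimension:

import torch
import torch.nn.functional as F

d = torch.tensor([64.0])        # channels per head (out_channels / groups)
q = torch.randn(2, 64, 128, 1)  # B, C, N, 1
k = torch.randn(2, 64, 128, 16) # B, C, N, K
attn = F.softmax(torch.rsqrt(d) * q * k, dim=-1)  # B, C, N, K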
Example #30
 def forward(self, x):
     x2 = x**2
     temp = torch.mean(x2, dim=1, keepdim=True)
     return x * torch.rsqrt(temp + self.epsilon)
Example #31
 def test_rsqrt(x, y):
     c = torch.rsqrt(torch.add(x, y))
     return c
Example #32
    def forward(self, x, style):
        """Forward function.

        Args:
            x ([Tensor): Input features with shape of (N, C, H, W).
            style (Tensor): Style latent with shape of (N, C).

        Returns:
            Tensor: Output feature with shape of (N, C, H, W).
        """
        n, c, h, w = x.shape
        # process style code
        style = self.style_modulation(style).view(n, 1, c, 1,
                                                  1) + self.style_bias

        # combine weight and style
        weight = self.weight * style
        if self.demodulate:
            demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
            weight = weight * demod.view(n, self.out_channels, 1, 1, 1)

        weight = weight.view(n * self.out_channels, c, self.kernel_size,
                             self.kernel_size)

        if self.upsample and not self.deconv2conv:
            x = x.reshape(1, n * c, h, w)
            weight = weight.view(n, self.out_channels, c, self.kernel_size,
                                 self.kernel_size)
            weight = weight.transpose(1, 2).reshape(n * c, self.out_channels,
                                                    self.kernel_size,
                                                    self.kernel_size)
            x = conv_transpose2d(x, weight, padding=0, stride=2, groups=n)
            x = x.reshape(n, self.out_channels, *x.shape[-2:])
            x = self.blur(x)
        elif self.upsample and self.deconv2conv:
            if self.up_after_conv:
                x = x.reshape(1, n * c, h, w)
                x = conv2d(x, weight, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[2:4])

            if self.with_interp_pad:
                h_, w_ = x.shape[-2:]
                up_cfg_ = deepcopy(self.up_config)
                up_scale = up_cfg_.pop('scale_factor')
                size_ = (h_ * up_scale + self.interp_pad,
                         w_ * up_scale + self.interp_pad)
                x = F.interpolate(x, size=size_, **up_cfg_)
            else:
                x = F.interpolate(x, **self.up_config)

            if not self.up_after_conv:
                h_, w_ = x.shape[-2:]
                x = x.view(1, n * c, h_, w_)
                x = conv2d(x, weight, padding=self.padding, groups=n)
                x = x.view(n, self.out_channels, *x.shape[2:4])

        elif self.downsample:
            x = self.blur(x)
            x = x.view(1, n * self.in_channels, *x.shape[-2:])
            x = conv2d(x, weight, stride=2, padding=0, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])
        else:
            x = x.view(1, n * c, h, w)
            x = conv2d(x, weight, stride=1, padding=self.padding, groups=n)
            x = x.view(n, self.out_channels, *x.shape[-2:])

        return x
Example #33
    def forward(self, input, style, coords=None, calc_flops=False):
        batch, in_channel, height, width = input.shape

        if calc_flops:
            flops = self.get_flops(input, style)
        else:
            flops = 0

        # Special case for spatially-shaped style
        # Here, we check up-front whether the whole feature map uses the same style.
        # If so, we simply use that single style; otherwise, we fall back to the slower path below.
        if (style is not None) and (style.ndim == 4):
            mean_style = style.mean([2, 3], keepdim=True)
            is_mono_style = ((style - mean_style) < 1e-8).all()
            if is_mono_style:
                style = mean_style.squeeze()

        # Regular forward
        if style.ndim == 2:
            style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
            # (1, ) * (1, out_ch, in_ch, k, k) * (B, 1, in_ch, 1, 1)
            # => (B, out_ch, in_ch, k, k)
            weight = self.scale * self.weight * style

            if self.demodulate:
                demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
                weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)

            weight = weight.view(batch * self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size)

            if self.upsample:
                input = input.view(1, batch * in_channel, height, width)
                weight = weight.view(batch, self.out_channel, in_channel,
                                     self.kernel_size, self.kernel_size)
                weight = weight.transpose(1,
                                          2).reshape(batch * in_channel,
                                                     self.out_channel,
                                                     self.kernel_size,
                                                     self.kernel_size)
                out = F.conv_transpose2d(input,
                                         weight,
                                         padding=0,
                                         stride=2,
                                         groups=batch)
                if self.no_zero_pad:
                    out = out[:, :, 1:-1, 1:
                              -1]  # Clipping head and tail, which involves zero padding
                _, _, height, width = out.shape
                out = out.view(batch, self.out_channel, height, width)
                out = self.blur(out)

            elif self.downsample:
                input = self.blur(input)
                _, _, height, width = input.shape
                input = input.view(1, batch * in_channel, height, width)
                out = F.conv2d(input,
                               weight,
                               padding=self.padding,
                               stride=2,
                               groups=batch)
                _, _, height, width = out.shape
                out = out.view(batch, self.out_channel, height, width)

            else:
                input = input.view(1, batch * in_channel, height, width)
                out = F.conv2d(input,
                               weight,
                               padding=self.padding,
                               groups=batch)
                _, _, height, width = out.shape
                out = out.view(batch, self.out_channel, height, width)
        else:
            assert (not self.training), \
                "Only accepts spatially-shaped global-latent for testing-time manipulation!"
            assert (style.ndim == 4), \
                "Only considered BxCxHxW case, but got shape {}".format(style.shape)

            # For simplicity (and laziness), we sometimes feed spatial latents
            # that are larger than the input; center-crop in such cases.
            style = self._auto_shape_align(source=style, target=input)

            # [Note]
            # Original (lossy expression):   input * (style * weight)
            # What we equivalently do here (still lossy): (input * style) * weight
            sb, sc, sh, sw = style.shape
            flat_style = style.permute(0, 2, 3, 1).reshape(-1,
                                                           sc)  # (BxHxW, C)
            style_mod = self.modulation(flat_style)  # (BxHxW, C)
            style_mod = style_mod.view(sb, sh, sw, self.in_channel).permute(
                0, 3, 1, 2)  # (B, C, H, W)

            input_st = (style_mod * input)  # (B, C, H, W)
            weight = self.scale * self.weight

            if self.demodulate:
                # [Hubert]
                # This is an approximation if the spatially fused styles differ.
                # In practice, interpolated styles do not change drastically (numerically), so the approximation is invisible.
                """
                # This is the implementation we shown in the paper Appendix, the for-loop is slow.
                # But this version surely allocates a constant amount of memory.
                for i in range(sh):
                    for j in range(sw):
                        style_expand_s = style_mod[:, :, i, j].view(sb, 1, self.in_channel, 1, 1) # shape: (B, 1, in_ch, 1, 1)
                        simulated_weight_s = weight * style_expand_s # shape: (B, out_ch, in_ch, k, k)
                        demod_s[:, :, i, j] = torch.rsqrt(simulated_weight_s.pow(2).sum([2, 3, 4]) + 1e-8) # shape: (B, out_ch)
                """
                """
                Logically equivalent version, omits one for-loop by batching one spatial dimension.
                """
                demod = torch.zeros(sb, self.out_channel, sh,
                                    sw).to(style.device)
                for i in range(sh):
                    style_expand = style_mod[:, :, i, :].view(
                        sb, 1, self.in_channel,
                        sw).pow(2)  # shape: (B, 1, in_ch, W)
                    weight_expand = weight.pow(2).sum([3, 4]).unsqueeze(
                        -1)  # shape: (B, out_ch, in_ch, 1)
                    simulated_weight = weight_expand * style_expand  # shape: (B, out_ch, in_ch, W)
                    demod[:, :,
                          i, :] = torch.rsqrt(simulated_weight.sum(2) +
                                              1e-8)  # shape: (B, out_ch, W)
                """ 
                # An even faster version that batches both the height and width dimensions, but it allocates far too much memory to be practical.
                # For instance, it allocates 40GB memory with shape (8, 512, 128, 3, 3, 31, 31).
                style_expand = style_mod.view(sb, 1, self.in_channel, 1, 1, sh, sw) # (B,      1  in_ch, 1, 1, H, W)
                weight_expand = weight.unsqueeze(5).unsqueeze(6)                    # (B, out_ch, in_ch, k, k, 1, 1)
                simulated_weight = weight_expand * style_expand # shape: (B, out_ch, in_ch, k, k, H, W)
                demod = torch.rsqrt(simulated_weight.pow(2).sum([2, 3, 4]) + 1e-8) # shape: (B, out_ch, H, W)
                """
                """ 
                # Just FYI. If you use the mean style over the patch, it creates blocky artifacts
                mean_style = style_mod.mean([2,3]).view(sb, 1, self.in_channel, 1, 1)
                simulated_weight_ = weight * mean_style # shape: (B, out_ch, in_ch, k, k)
                demod_ = torch.rsqrt(simulated_weight_.pow(2).sum([2, 3, 4]) + 1e-8)
                demod_ = demod_.unsqueeze(2).unsqueeze(3)
                """

            weight = weight.view(self.out_channel, in_channel,
                                 self.kernel_size, self.kernel_size)

            if self.upsample:
                weight = weight.transpose(0, 1).contiguous()
                out = F.conv_transpose2d(input_st,
                                         weight,
                                         padding=0,
                                         stride=2,
                                         groups=1)
                out = out[:, :, 1:-1, 1:
                          -1]  # Clipping head and tail, which involves zero padding
                _, _, height, width = out.shape
                if self.demodulate:
                    demod = F.interpolate(demod,
                                          size=(height, width),
                                          mode="bilinear",
                                          align_corners=True)
                    out = out * demod
                out = self.blur(out)
            elif self.downsample:
                input_st = self.blur(input_st)
                out = F.conv2d(input_st,
                               weight,
                               padding=self.padding,
                               stride=2,
                               groups=1)
                if self.demodulate:
                    raise NotImplementedError("Unused, not implemented!")
                    out = out * demod
            else:
                out = F.conv2d(input_st,
                               weight,
                               padding=self.padding,
                               groups=1)
                if self.demodulate and (self.padding == 0):
                    demod = demod[:, :,
                                  self.dirty_rm_size[0]:-self.dirty_rm_size[0],
                                  self.dirty_rm_size[1]:-self.dirty_rm_size[1]]
                    out = out * demod

            out = out.contiguous()  # Not sure where the discontiguity comes from.

        return out, flops