Code Example #1
    def forward(self, x, **kwargs):

        if hasattr(self, 'in_sequence') and self.in_sequence:
            x = x.permute(0, 2, 1)
        x = F.instance_norm(x, self.running_mean, self.running_var,
                            self.weight, self.bias, self.training
                            or not self.track_running_stats, self.momentum,
                            self.eps)

        if hasattr(self, 'in_sequence') and self.in_sequence:
            x = x.permute(0, 2, 1)
        return x
Code Example #2
 def test_instance_norm(self):
     inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype)
     running_mean = torch.randn(3, device='cuda', dtype=self.dtype)
     running_var = torch.randn(3, device='cuda', dtype=self.dtype)
     output = F.instance_norm(inp,
                              running_mean=running_mean,
                              running_var=running_var,
                              weight=None,
                              bias=None,
                              use_input_stats=True,
                              momentum=0.1,
                              eps=1e-05)
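
The call above supplies running statistics but keeps use_input_stats=True, so the normalization is still computed from the input itself (the running buffers are merely updated). A minimal sketch of that behaviour, with shapes assumed for illustration: each (sample, channel) slice is standardized by its own spatial mean and biased variance.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 32, 32)
out = F.instance_norm(x, use_input_stats=True, eps=1e-5)

# Manual per-(sample, channel) standardization for comparison.
mean = x.mean(dim=(2, 3), keepdim=True)
var = x.var(dim=(2, 3), unbiased=False, keepdim=True)  # biased variance, as used by instance norm
manual = (x - mean) / torch.sqrt(var + 1e-5)

print(torch.allclose(out, manual, atol=1e-5))  # expected: True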
Code Example #3
 def forward(self, x):
     x = transpose(x, 1, 2)
     if x.size(0) == 1:  # Sample size == 1
         if x.dim() == 2:
             x = x.unsqueeze(2)
         x = F.instance_norm(x)
         x = x.squeeze(2)
         x = transpose(x, 1, 2)
         return x
     else:
         x = self.layer(x)
         return transpose(x, 1, 2)
Code Example #4
    def forward(self, x: Tensor, batch: OptTensor = None) -> Tensor:
        """"""
        if batch is None:
            out = F.instance_norm(
                x.t().unsqueeze(0), self.running_mean, self.running_var,
                self.weight, self.bias, self.training
                or not self.track_running_stats, self.momentum, self.eps)
            return out.squeeze(0).t()

        batch_size = int(batch.max()) + 1

        mean = var = unbiased_var = x  # Dummies.

        if self.training or not self.track_running_stats:
            norm = degree(batch, batch_size, dtype=x.dtype).clamp_(min=1)
            norm = norm.view(-1, 1)
            unbiased_norm = (norm - 1).clamp_(min=1)

            mean = scatter(x, batch, dim=0, dim_size=batch_size,
                           reduce='add') / norm

            x = x - mean[batch]

            var = scatter(x * x,
                          batch,
                          dim=0,
                          dim_size=batch_size,
                          reduce='add')
            unbiased_var = var / unbiased_norm
            var = var / norm

            momentum = self.momentum
            if self.running_mean is not None:
                self.running_mean = (
                    1 - momentum) * self.running_mean + momentum * mean.mean(0)
            if self.running_var is not None:
                self.running_var = (
                    1 - momentum
                ) * self.running_var + momentum * unbiased_var.mean(0)
        else:
            if self.running_mean is not None:
                mean = self.running_mean.view(1, -1).expand(batch_size, -1)
            if self.running_var is not None:
                var = self.running_var.view(1, -1).expand(batch_size, -1)

            x = x - mean[batch]

        out = x / (var + self.eps).sqrt()[batch]

        if self.weight is not None and self.bias is not None:
            out = out * self.weight.view(1, -1) + self.bias.view(1, -1)

        return out
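
The branch without a batch vector falls back to plain F.instance_norm over the transposed node features; with a batch vector, each graph is normalized with its own statistics. A hedged usage sketch, assuming this forward belongs to PyTorch Geometric's InstanceNorm (torch_geometric.nn.InstanceNorm):

import torch
from torch_geometric.nn import InstanceNorm

norm = InstanceNorm(in_channels=16)
x = torch.randn(10, 16)                   # 10 nodes, 16 features each
batch = torch.tensor([0] * 6 + [1] * 4)   # two graphs in one batch
out = norm(x, batch)                      # each graph normalized separately
print(out.shape)                          # torch.Size([10, 16])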
Code Example #5
    def forward(self, x):
        d1 = F.leaky_relu(self.d1(x), 0.2)
        d2 = F.instance_norm(F.leaky_relu(self.d2(d1), 0.2))
        d3 = F.instance_norm(F.leaky_relu(self.d3(d2), 0.2))
        d4 = F.instance_norm(F.leaky_relu(self.d4(d3), 0.2))

        x = F.upsample(d4, scale_factor=2)
        x = nn.ZeroPad2d((1, 0, 1, 0))(x)
        x = self.u1(x)
        x = F.tanh(x)
        x = F.instance_norm(x)
        u1 = torch.cat((x, d3), 1)
        x = F.upsample(u1, scale_factor=2)
        x = self.u2(x)
        x = F.tanh(x)
        x = F.instance_norm(x)
        u2 = torch.cat((x, d2), 1)
        x = F.upsample(u2, scale_factor=2)
        x = nn.ZeroPad2d((1, 0, 1, 0))(x)
        x = self.u3(x)
        x = F.tanh(x)
        x = F.instance_norm(x)
        u3 = torch.cat((x, d1), 1)
        #x = F.upsample(x, scale_factor=2)
        #x = nn.ZeroPad2d((1,0,1,0))(x)
        #x = self.u4(x)
        x = F.upsample(u3, scale_factor=2)
        x = nn.ZeroPad2d((2, 1, 2, 1))(x)
        x = F.tanh(self.output(x))
        return x
Code Example #6
def compareFunc(x, Kopen, normFlag='batch', eps=1e-5, device=torch.device('cpu'),
                weightBias=(None, None)):
    # weightBias: optional (weight, bias) pair for the instance-norm affine transform,
    # added as an explicit parameter so the snippet is self-contained.
    x = conv3x3(x, Kopen)
    if normFlag == 'batch':  # compare strings with ==, not `is`
        x = F.batch_norm(x,
                         running_mean = torch.zeros(Kopen.size(0)).to(device),
                         running_var  = torch.ones(Kopen.size(0)).to(device),
                         weight       = torch.ones(Kopen.size(0)).to(device),
                         bias         = torch.zeros(Kopen.size(0)).to(device),
                         training = True, eps=eps)
    elif normFlag == 'instance':
        x = F.instance_norm(x, weight=weightBias[0], bias=weightBias[1])
    x = F.relu(x)
    return x
Code Example #7
File: cn_net.py Project: sunbirddy/consac
    def forward(self, inputs, mask=None):
        '''
        Forward pass.

        inputs -- 4D data tensor (BxCxHxW)
        '''
        inputs = torch.transpose(inputs, 1, 2).unsqueeze(-1)

        batch_size = inputs.size(0)
        data_size = inputs.size(2)

        x = inputs[:, 0:self.input_dim]
        x = F.relu(self.p_in(x))

        for r in self.res_blocks:
            res = x
            if mask is None:
                if self.batch_norm:
                    x = F.relu(r[1](F.instance_norm(r[0](x))))
                    x = F.relu(r[3](F.instance_norm(r[2](x))))
                else:
                    x = F.relu(F.instance_norm(r[0](x)))
                    x = F.relu(F.instance_norm(r[1](x)))
            else:
                x = F.relu(r[1](self.masked_instance_norm(r[0](x), mask)))
                x = F.relu(r[3](self.masked_instance_norm(r[2](x), mask)))
            x = x + res

        log_probs = F.logsigmoid(self.p_out(x))

        # normalization
        log_probs = log_probs.view(batch_size, -1)
        normalizer = torch.logsumexp(log_probs, dim=1)
        normalizer = normalizer.unsqueeze(1).expand(-1, data_size)
        log_probs = log_probs - normalizer
        log_probs = log_probs.view(batch_size, 1, data_size, 1)

        return log_probs
Code Example #8
File: batch_norm.py Project: serre-lab/pred_gn
 def forward(self, input):
     # self._check_input_dim(input)
     x = input.reshape([input.size(0), self.num_groups,
                        -1])  #input.size(1)//self.num_groups,-1])
     if self.reset:
         self.running_mean = x.mean((0, 2)).detach()
         self.running_var = x.var((0, 2)).detach()  # variance, not mean (apparent slip in the original)
         self.num_batches_tracked.zero_()
         self.reset = False
     return F.instance_norm(
         x, self.running_mean, self.running_var, None, None, True,
         self.momentum, self.eps).reshape(
             input.shape) * self.weight[:, None, None] + self.bias[:, None,
                                                                   None]
Code Example #9
def adain(y, x):
    """
    Adaptive instance normalization
    :param y: parameters for the normalization (first half of the channels: scale, second half: shift)
    :param x: input to normalize
    :return: the normalized x, scaled and shifted by y
    """
    b, c, h, w = y.size()

    ys = y[:, :c//2, :, :]
    yb = y[:, c//2:, :, :]

    x = F.instance_norm(x)

    return ys * x + yb
Code Example #10
def adain(y, x):
    """
    Adaptive instance normalization
    :param y: Parameters for the normalization
    :param x: Input to normalize
    :return:
    """
    b, c, h, w = y.size()

    ys = y[:, :c // 2, :, :]
    yb = y[:, c // 2:, :, :]

    x = F.instance_norm(x)

    return (ys + 1.) * x + yb
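
A quick usage sketch for the adain above (shapes are hypothetical): y carries twice as many channels as x, with the first half acting as per-position scales and the second half as shifts.

import torch

x = torch.randn(4, 64, 16, 16)   # content features, C = 64
y = torch.randn(4, 128, 16, 16)  # style parameters, 2 * C channels

out = adain(y, x)                # adain as defined just above
print(out.shape)                 # torch.Size([4, 64, 16, 16])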
Code Example #11
    def forward(self, input, ConInfor):
        self._check_input_dim(input)
        b, c = input.size(0), input.size(1)
        out = F.instance_norm(input, self.running_mean, self.running_var, None,
                              None, self.training
                              or not self.track_running_stats, self.momentum,
                              self.eps)

        if self.num_con > 0:
            weight = self.ConAlpha(ConInfor).view(b, c, 1, 1)
            bias = self.ConBeta(ConInfor).view(b, c, 1, 1)
        else:
            weight = 1
            bias = 0
        return out.view(b, c, *input.size()[2:]) * weight + bias
Code Example #12
 def forward(self, x):
     d1 = F.leaky_relu(self.d1(x), 0.2)
     x = F.max_pool2d(d1, 2)
     d2 = F.instance_norm(F.leaky_relu(self.d2(x), 0.2))
     x = F.max_pool2d(d2, 2)
     d3 = F.instance_norm(F.leaky_relu(self.d3(x), 0.2))
     encoder = self.enmaxpool(d3)
     #d4 = F.instance_norm(F.leaky_relu(self.d4(d3), 0.2))
     #x = F.upsample(d3, scale_factor=2)
     x = self.up1(encoder)
     #x = nn.ZeroPad2d((1,0,1,0))(x)
     x = self.u1(x)
     x = F.leaky_relu(x, 0.2)
     x = F.instance_norm(x)
     u1 = torch.cat((x, d3), 1)
     x = self.up1(u1)
     #x = F.upsample(u1, scale_factor=2)
     x = self.u2(x)
     x = F.leaky_relu(x, 0.2)
     x = F.instance_norm(x)
     u2 = torch.cat((x, d2), 1)
     x = self.up1(u2)
     #x = F.upsample(u2, scale_factor=2)
     #x = nn.ZeroPad2d((1,0,1,0))(x)
     x = self.u3(x)
     x = F.leaky_relu(x, 0.2)
     x = F.instance_norm(x)
     u3 = torch.cat((x, d1), 1)
     #x = F.upsample(x, scale_factor=2)
     #x = nn.ZeroPad2d((1,0,1,0))(x)
     #x = self.u4(x)
     # x = F.upsample(u3, scale_factor=2)
     #x = nn.ZeroPad2d((2,1,2,1))(x)
     x = self.output(u3)
     x = F.relu(x)
     return x
Code Example #13
    def forward(self, x):
        residual = x
        residual = self.avgPool(residual)
        if hasattr(self, "adaDim"):
            # fill_to_out_channels = self.adaDim(residual)
            # residual = torch.cat((residual, fill_to_out_channels), dim=1)
            residual = self.adaDim(residual)

        out = F.instance_norm(x)
        out = self.relu(out)
        out = self.conv1(out)
        out = F.instance_norm(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = F.instance_norm(out)
        out = self.relu(out)
        out = self.conv3(out)
        out = F.instance_norm(out)
        out = self.relu(out)
        out = self.avgPool(out)
        out = self.conv4(out)

        out += residual
        return out
Code Example #14
 def forward(self, x):
     self.eval()
     x = F.pad(F.instance_norm(x), (15, 15, 15, 15), 'reflect')
     x = F.relu(F.instance_norm(self.c1(x)), inplace=True)
     x = F.relu(F.instance_norm(self.c2(x)), inplace=True)
     x = F.relu(F.instance_norm(self.c3(x)), inplace=True)
     x = F.relu(F.instance_norm(self.c4(x)), inplace=True)
     x = F.relu(F.instance_norm(self.c5(x)), inplace=True)
     self.train()
     return x
Code Example #15
    def forward(self, input, ConInfor):
        self._check_input_dim(input)
        b, c = input.size(0), input.size(1)
        tarBias = self.ConBias(ConInfor).view(b, c, 1, 1)
        out = F.instance_norm(input, self.running_mean, self.running_var, None,
                              None, self.training
                              or not self.track_running_stats, self.momentum,
                              self.eps)

        if self.affine:
            bias = self.bias.repeat(b).view(b, c, 1, 1)
            weight = self.weight.repeat(b).view(b, c, 1, 1)
            return (out.view(b, c,
                             *input.size()[2:]) + tarBias) * weight + bias
        else:
            return out.view(b, c, *input.size()[2:]) + tarBias
Code Example #16
    def forward(self, input, label):
        # self._check_input_dim(input)
        if label >= self.num_labels:
            raise ValueError(
                'Expected label to be less than {} but got {}'.format(
                    self.num_labels, label))
        w = self.weight
        b = self.bias
        if self.affine:
            w = self.weight[label, :]
            b = self.bias[label, :]

        return F.instance_norm(input, self.running_mean, self.running_var, w,
                               b, self.training
                               or not self.track_running_stats, self.momentum,
                               self.eps)
Code Example #17
 def forward(self, input):
     if self.transpose_last:
         input = input.transpose(1, 2)
     output = F.instance_norm(
         input.float(),
         running_mean=self.running_mean,
         running_var=self.running_var,
         weight=self.weight.float() if self.weight is not None else None,
         bias=self.bias.float() if self.bias is not None else None,
         use_input_stats=self.training or not self.track_running_stats,
         momentum=self.momentum,
         eps=self.eps,
     )
     if self.transpose_last:
         output = output.transpose(1, 2)
     return output.type_as(input)
Code Example #18
    def test_no_quant(self):

        quant_instancenorm_object = quant_instancenorm.QuantInstanceNorm1d(
            NUM_CHANNELS, affine=True)
        quant_instancenorm_object.input_quantizer.disable()

        test_input = torch.randn(8, NUM_CHANNELS, 128)

        out1 = quant_instancenorm_object(test_input)
        out2 = F.instance_norm(test_input,
                               quant_instancenorm_object.running_mean,
                               quant_instancenorm_object.running_var,
                               quant_instancenorm_object.weight,
                               quant_instancenorm_object.bias)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(),
                                      out2.detach().cpu().numpy())
Code Example #19
def function_hook(input, *args, **kwargs):
    class InstanceNorm(nn.Module):
        def __init__(self, running_mean, running_var, weight, bias,
                     use_input_stats, momentum, eps):
            super().__init__()
            self.running_mean = running_mean
            self.running_var = running_var
            self.weight = weight
            self.bias = bias
            self.use_input_stats = use_input_stats
            self.momentum = momentum
            self.eps = eps
            self.dims = input.dim()

    output = F.instance_norm(input.tensor(), *args, **kwargs)
    return forward_hook(InstanceNorm(*args, **kwargs), (input, ), output)
Code Example #20
    def test_fake_quant_per_tensor(self):

        quant_instancenorm_object = quant_instancenorm.QuantInstanceNorm1d(
            NUM_CHANNELS, affine=True, quant_desc_input=QuantDescriptor())

        test_input = torch.randn(8, NUM_CHANNELS, 128)
        quant_input = tensor_quant.fake_tensor_quant(
            test_input, torch.max(torch.abs(test_input)))

        out1 = quant_instancenorm_object(test_input)
        out2 = F.instance_norm(quant_input,
                               quant_instancenorm_object.running_mean,
                               quant_instancenorm_object.running_var,
                               quant_instancenorm_object.weight,
                               quant_instancenorm_object.bias)
        np.testing.assert_array_equal(out1.detach().cpu().numpy(),
                                      out2.detach().cpu().numpy())
Code Example #21
    def forward(self, x, y, style=None):
        # Calculate class-conditional gains and biases
        if self.use_dog_cnt and not self.g_shared:
            gain = (1 + self.gain(y[:, 0])).view(y.size(0), -1, 1, 1)
            bias = self.bias(y[:, 0]).view(y.size(0), -1, 1, 1)

            gain_dog_cnt = (1 + self.gain_dog_cnt(y[:, 1])).view(
                y.size(0), -1, 1, 1)
            bias_dog_cnt = self.bias_dog_cnt(y[:, 1]).view(y.size(0), -1, 1, 1)
        else:
            gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
            bias = self.bias(y).view(y.size(0), -1, 1, 1)

        # If using my batchnorm
        if self.mybn or self.cross_replica:
            if style is None:
                return self.bn(x, gain=gain, bias=bias)
            else:
                out = self.bn(x, gain=gain, bias=bias)
                style = self.style(style).unsqueeze(2).unsqueeze(3)
                gamma, beta = style.chunk(2, 1)
                out = gamma * out + beta
                return out
        # else:
        else:
            if self.norm_style == 'bn':
                out = F.batch_norm(x, self.stored_mean, self.stored_var, None,
                                   None, self.training, 0.1, self.eps)
            elif self.norm_style == 'in':
                out = F.instance_norm(x, self.stored_mean, self.stored_var,
                                      None, None, self.training, 0.1, self.eps)
            elif self.norm_style == 'gn':
                out = groupnorm(x, self.normstyle)
            elif self.norm_style == 'nonorm':
                out = x

            if not self.no_conditional:
                out = out * gain + bias
            if self.use_dog_cnt and not self.g_shared:
                out = out * gain_dog_cnt + bias_dog_cnt
            if style is not None:
                style = self.style(style).unsqueeze(2).unsqueeze(3)
                gamma, beta = style.chunk(2, 1)
                out = gamma * out + beta
            return out
Code Example #22
 def orient_features(self, feat_map):
     B, C, H, W = feat_map.shape
     ori_maps = self.orient(feat_map)
     ori_maps = ori_maps / ori_maps.norm(2, dim=1).unsqueeze(1)
     score_maps = []
     for i, scale in enumerate(self.scale_factors):
         resized_feat_map = NF.interpolate(feat_map,
                                           scale_factor=1 / scale,
                                           mode='bilinear')
         resized_score_map = self.scale_convs[i](resized_feat_map)
         normalized_resized_score_map = NF.instance_norm(resized_score_map,
                                                         eps=1e-3)
         normalized_score_map = NF.interpolate(normalized_resized_score_map,
                                               size=(H, W),
                                               mode='bilinear')
         score_maps.append(normalized_score_map)
     scale_maps = torch.cat(score_maps, 1)
     return scale_maps, ori_maps
Code Example #23
File: layer.py Project: BARANCE/StyleGanPytorch
    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        """順方向伝搬
        AdaINを反映する.

        Args:
            x (torch.Tensor): コンテンツ画像を表すテンソル
            w (torch.Tensor): スタイル画像を表すテンソル

        Returns:
            torch.Tensor: AdaINを適用したテンソル
        """
        x = F.instance_norm(x, eps=self.epsilon)

        # scale
        scale = self.scale_transform(w)
        bias = self.bias_transform(w)

        return scale * x + bias
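
The layer above assumes scale_transform and bias_transform map the latent w to one scale and one bias per feature channel. A minimal, self-contained sketch of such a module; the Linear maps, names, and shapes below are assumptions for illustration, not the project's actual code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class AdaINSketch(nn.Module):
    def __init__(self, channels: int, w_dim: int, epsilon: float = 1e-8):
        super().__init__()
        self.epsilon = epsilon
        self.scale_transform = nn.Linear(w_dim, channels)  # w -> per-channel scale
        self.bias_transform = nn.Linear(w_dim, channels)   # w -> per-channel bias

    def forward(self, x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
        x = F.instance_norm(x, eps=self.epsilon)
        scale = self.scale_transform(w).unsqueeze(-1).unsqueeze(-1)
        bias = self.bias_transform(w).unsqueeze(-1).unsqueeze(-1)
        return scale * x + bias

layer = AdaINSketch(channels=64, w_dim=512)
print(layer(torch.randn(2, 64, 32, 32), torch.randn(2, 512)).shape)  # torch.Size([2, 64, 32, 32])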
Code Example #24
File: cn_net.py Project: sunbirddy/consac
    def masked_instance_norm(self, data, mask):

        B = data.size(0)

        num_elements = mask.sum(-1)

        new_data_batch = []
        for bi in range(B):
            new_data_a = F.instance_norm(data[bi, :, :num_elements[bi]])
            if num_elements[bi] < data.size(2):
                new_data_b = data[bi, :, num_elements[bi]:]
                new_data = torch.cat([new_data_a, new_data_b], dim=1)
            else:
                new_data = new_data_a
            new_data_batch += [new_data]

        data = torch.stack(new_data_batch, dim=0)

        return data
Code Example #25
 def forward(self, x):  # input: 3 × 256 × 256
     o1 = self.refpad(x)
     o2 = F.relu(F.instance_norm(self.conv1(o1)))
     o3 = F.relu(F.instance_norm(self.conv2(o2)))
     o4 = F.relu(F.instance_norm(self.conv3(o3)))
     o5 = self.block1(o4)
     o6 = self.block2(o5)
     o7 = self.block3(o6)
     o8 = self.block4(o7)
     o9 = self.block5(o8)
     o10 = F.relu(F.instance_norm(self.conv4(o9)))
     o11 = F.relu(F.instance_norm(self.conv5(o10)))
     o12 = torch.tanh(F.instance_norm(self.conv6(o11)))
     return o12
Code Example #26
File: layers.py Project: yqGANs/LOGAN
 def forward(self, x, y):
     # Calculate class-conditional gains and biases
     gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
     bias = self.bias(y).view(y.size(0), -1, 1, 1)
     # If using my batchnorm
     if self.mybn:
         return self.bn(x, gain=gain, bias=bias)
     # else:
     else:
         if self.norm_style == 'bn':
             out = F.batch_norm(x, self.stored_mean, self.stored_var, None, None,
                                self.training, 0.1, self.eps)
         elif self.norm_style == 'in':
             out = F.instance_norm(x, self.stored_mean, self.stored_var, None, None,
                                   self.training, 0.1, self.eps)
         elif self.norm_style == 'gn':
             out = groupnorm(x, self.normstyle)
         elif self.norm_style == 'nonorm':
             out = x
         return out * gain + bias
Code Example #27
File: da.py Project: zhmd/salad
    def forward(self, x, y):
        N = y.size()[0]

        sd = y.view(N, self.n_channels, -1).std(dim=-1)
        mu = y.view(N, self.n_channels, -1).mean(dim=-1)

        x_ = F.instance_norm(
            x,
            running_mean=None,
            running_var=None,
            weight=None,  #(self.eps + var)**.5,
            bias=None,
            use_input_stats=True,
            momentum=0.,
            eps=self.eps)

        x_ = x_ * sd.unsqueeze(-1).unsqueeze(-1) + mu.unsqueeze(-1).unsqueeze(
            -1)

        return x_
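
This is AdaIN-style statistic matching: after the transform, each channel of x takes on the per-channel mean and standard deviation measured from y. A small check under assumed shapes:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 16, 16) * 3.0 + 1.0   # source features
y = torch.randn(2, 8, 16, 16) * 0.5 - 2.0   # features providing the target statistics

sd = y.view(2, 8, -1).std(dim=-1)
mu = y.view(2, 8, -1).mean(dim=-1)

x_ = F.instance_norm(x, use_input_stats=True, momentum=0.0, eps=1e-5)
x_ = x_ * sd.unsqueeze(-1).unsqueeze(-1) + mu.unsqueeze(-1).unsqueeze(-1)

# Per-channel means of the result now match those of y (up to float error).
print(torch.allclose(x_.view(2, 8, -1).mean(dim=-1), mu, atol=1e-4))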
Code Example #28
    def forward(self, x):

        # calculate act
        act = self.sg_conv(x)

        b, c, w, h = x.shape

        # loss
        tmp = act + 1e-3
        tmp = tmp.reshape((-1, w*h))
        tmp = F.instance_norm(tmp)
        tmp = tmp.reshape((-1, self.out_channels, w*h))
        tmp = tmp.permute(1, 0, 2).reshape(self.out_channels, -1)

        co_matrix = torch.matmul(tmp, tmp.t()).reshape((1, self.out_channels**2))
        co_matrix /= self.batch_size

        loss = torch.sum((co_matrix-self.gt)*(co_matrix-self.gt)*0.001, dim=1).repeat(self.batch_size)
        self.loss = loss/((self.out_channels/512.0)**2)
        return act
Code Example #29
 def forward(self, inpt, label):
     # self._check_input_dim(inpt)
     ins = F.instance_norm(inpt, self.running_mean, self.running_var, None,
                           None, self.training
                           or not self.track_running_stats, self.momentum,
                           self.eps)
     if torch.max(label) >= self.num_labels:
         raise ValueError(
             'Expected label to be less than {} but got {}'.format(
                 self.num_labels, label))
     w = self.weight
     b = self.bias
     if self.affine:
         w = self.weight[label].view(
             inpt.size(0), self.num_features).unsqueeze(2).unsqueeze(3)
         b = self.bias[label].view(
             inpt.size(0), self.num_features).unsqueeze(2).unsqueeze(3)
         return ins * w + b
     else:
         return ins
Code Example #30
    def forward(self, x, w, w1=None, ratio=[0,1], interp_mode='inter'):
        x = F.instance_norm(x, eps=self.epsilon)
    
        scale = self.scale_transform(w)
        bias = self.bias_transform(w)

        # Style mixing
        if w1 is not None:
            scale1 = self.scale_transform(w1)
            bias1 = self.bias_transform(w1)
            if interp_mode == 'inter':
                scale = (ratio[0] * scale + ratio[1] * scale1) / (ratio[0]+ratio[1])
                bias = (ratio[0] * bias + ratio[1] * bias1) / (ratio[0]+ratio[1])
            elif interp_mode == 'extra':
                scale = (ratio[0] * scale - ratio[1] * scale1) / (ratio[0]-ratio[1])
                bias = (ratio[0] * bias - ratio[1] * bias1) / (ratio[0]-ratio[1])
            else:
                raise ValueError(f'Invalid interpolation mode {interp_mode}.')
    
        return scale * x + bias
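
Note the default ratio=[0, 1]: when w1 is supplied in 'inter' mode, the mix collapses to w1's style entirely, since (0 * scale + 1 * scale1) / (0 + 1) = scale1. A tiny numeric illustration of the two mixing modes (values are arbitrary):

import torch

scale, scale1 = torch.tensor(2.0), torch.tensor(4.0)

ratio = [0.5, 0.5]
inter = (ratio[0] * scale + ratio[1] * scale1) / (ratio[0] + ratio[1])  # 3.0, weighted average
ratio = [1.5, 0.5]
extra = (ratio[0] * scale - ratio[1] * scale1) / (ratio[0] - ratio[1])  # 1.0, pushed past scale, away from scale1

print(inter.item(), extra.item())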