Example #1

# Common imports used by the examples below.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import truncnorm
from torch.autograd import Variable
from torch.nn import Module, Parameter, init

class WN_Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 train_scale=False, init_stdv=1.0):
        super(WN_Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        if train_scale:
            self.weight_scale = Parameter(torch.Tensor(out_channels))
        else:
            self.register_buffer('weight_scale', torch.Tensor(out_channels))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv

        self._reset_parameters()

    def _reset_parameters(self):
        self.weight.data.normal_(std=0.05)
        if self.bias is not None:
            self.bias.data.zero_()
        if self.train_scale:
            self.weight_scale.data.fill_(1.)
        else:
            self.weight_scale.fill_(1.)

    def forward(self, input):
        if self.train_scale:
            weight_scale = self.weight_scale
        else:
            weight_scale = Variable(self.weight_scale)
        # normalize weight matrix and linear projection [out x in x h x w]
        # for each output dimension, normalize through (in, h, w) = (1, 2, 3) dims
        norm_weight = self.weight * (
            weight_scale[:, None, None, None] / torch.sqrt(
                (self.weight ** 2).sum(3, keepdim=True).sum(2, keepdim=True).sum(1, keepdim=True) + 1e-6)).expand_as(
            self.weight)
        activation = F.conv2d(input, norm_weight, bias=None,
                              stride=self.stride, padding=self.padding,
                              dilation=self.dilation, groups=self.groups)

        if self.init_mode:
            mean_act = activation.mean(3).mean(2).mean(0).squeeze()
            activation = activation - mean_act[None, :, None, None].expand_as(activation)

            inv_stdv = self.init_stdv / torch.sqrt(
                (activation ** 2).mean(3, keepdim=True).mean(2, keepdim=True).mean(0, keepdim=True) + 1e-6).squeeze()
            activation = activation * inv_stdv[None, :, None, None].expand_as(activation)

            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.data
            else:
                self.weight_scale = self.weight_scale * inv_stdv.data
            if self.bias is not None:
                self.bias.data = - mean_act.data * inv_stdv.data

        else:
            if self.bias is not None:
                activation = activation + self.bias[None, :, None, None].expand_as(activation)

        return activation
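
A minimal usage sketch for WN_Conv2d; the shapes, the train_scale flag, and the init-then-train sequence below are illustrative assumptions, not part of the original snippet.

# Hypothetical usage: run one batch in init_mode to calibrate weight_scale and bias,
# then switch to normal forward passes.
conv = WN_Conv2d(3, 16, kernel_size=3, padding=1, train_scale=True)
x = torch.randn(8, 3, 32, 32)
conv.init_mode = True
_ = conv(x)                # data-dependent initialization on this batch
conv.init_mode = False
out = conv(x)              # -> torch.Size([8, 16, 32, 32])
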
Example #2

class NoisyLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(NoisyLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.weight = torch.Tensor(out_features, in_features)
        self.weight_epsilon = torch.Tensor(out_features, in_features)
        self.weight_mu = Parameter(torch.Tensor(out_features, in_features))
        self.weight_sigma = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = torch.Tensor(out_features)
            self.bias_epsilon = torch.Tensor(out_features)
            self.bias_mu = Parameter(torch.Tensor(out_features))
            self.bias_sigma = Parameter(torch.Tensor(out_features))
        else:
            self.bias = None
            self.bias_epsilon = None
            self.register_parameter('bias_mu', None)
            self.register_parameter('bias_sigma', None)
        self.reset_parameters()
        self.sampled = False

    def sample(self):
        if self.training:
            self.weight_epsilon.normal_()
            self.weight = self.weight_epsilon.mul(self.weight_sigma).add_(
                self.weight_mu)
            if self.bias is not None:
                self.bias_epsilon.normal_()
                self.bias = self.bias_epsilon.mul(self.bias_sigma).add_(
                    self.bias_mu)
        else:
            self.weight = self.weight_mu.detach()
            if self.bias is not None:
                self.bias = self.bias_mu.detach()
        self.sampled = True

    def reset_parameters(self):
        stdv = math.sqrt(3.0 / self.weight.size(1))
        with torch.no_grad():  # in-place initialization of leaf parameters
            self.weight_mu.uniform_(-stdv, stdv)
            self.weight_sigma.fill_(0.017)
            if self.bias is not None:
                self.bias_mu.uniform_(-stdv, stdv)
                self.bias_sigma.fill_(0.017)

    def forward(self, input):
        if not self.sampled:
            self.sample()
        return F.linear(input, self.weight, self.bias)

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
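
A short usage sketch for NoisyLinear; the batch and feature sizes are assumptions. Note that sample() has to be called again whenever fresh noise is wanted, since forward() only samples once.

# Hypothetical usage: per-weight Gaussian noise in training, mean weights in eval.
layer = NoisyLinear(128, 64)
x = torch.randn(32, 128)
layer.train()
layer.sample()             # draws new epsilon and rebuilds weight/bias
y_train = layer(x)         # -> torch.Size([32, 64])
layer.eval()
layer.sample()             # eval mode falls back to weight_mu / bias_mu
y_eval = layer(x)
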
Example #3
class WN_Linear(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, train_scale=False, init_stdv=1.0):
        super(WN_Linear, self).__init__(in_features, out_features, bias=bias)
        if train_scale:
            self.weight_scale = Parameter(torch.ones(self.out_features))
        else:
            self.register_buffer('weight_scale', torch.Tensor(out_features))

        self.train_scale = train_scale
        self.init_mode = False
        self.init_stdv = init_stdv

        self._reset_parameters()

    def _reset_parameters(self):
        self.weight.data.normal_(0, std=0.05)
        if self.bias is not None:
            self.bias.data.zero_()
        if self.train_scale:
            self.weight_scale.data.fill_(1.)
        else:
            self.weight_scale.fill_(1.)

    def forward(self, input):
        if self.train_scale:
            weight_scale = self.weight_scale
        else:
            weight_scale = Variable(self.weight_scale)

        # normalize weight matrix and linear projection
        norm_weight = self.weight * (
            weight_scale.unsqueeze(1) / torch.sqrt((self.weight ** 2).sum(1, keepdim=True) + 1e-6)).expand_as(
            self.weight)
        activation = F.linear(input, norm_weight)

        if self.init_mode:
            mean_act = activation.mean(0).squeeze(0)
            activation = activation - mean_act.expand_as(activation)

            inv_stdv = self.init_stdv / torch.sqrt((activation ** 2).mean(0, keepdim=True) + 1e-6).squeeze(0)
            activation = activation * inv_stdv.expand_as(activation)

            if self.train_scale:
                self.weight_scale.data = self.weight_scale.data * inv_stdv.data
            else:
                self.weight_scale = self.weight_scale * inv_stdv.data
            if self.bias is not None:
                self.bias.data = - mean_act.data * inv_stdv.data

        else:
            if self.bias is not None:
                activation = activation + self.bias.expand_as(activation)

        return activation
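
The linear variant follows the same init-then-train pattern; a brief sketch with assumed sizes:

# Hypothetical usage mirroring the WN_Conv2d example above.
fc = WN_Linear(256, 128, train_scale=True)
x = torch.randn(64, 256)
fc.init_mode = True
_ = fc(x)                  # calibrate weight_scale and bias on this batch
fc.init_mode = False
out = fc(x)                # -> torch.Size([64, 128])
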
Example #4
class TestModule(Module):
    def __init__(self, input_num):
        super(TestModule, self).__init__()
        self.param = Parameter(torch.Tensor(1, input_num))
        self.reset_parameters()

    def reset_parameters(self):
        # fake data
        with torch.no_grad():
            self.param.fill_(0)
            self.param[0, 0] = 0.25111
            self.param[0, 1] = 0.5

    def forward(self, input):
        return F.linear(input, self.param, None)
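
A quick sanity check for TestModule; the input below is made up. The fixed projection returns 0.25111 * x[0] + 0.5 * x[1].

# Hypothetical check of the hard-coded weights.
m = TestModule(4)
print(m(torch.tensor([[1., 1., 0., 0.]])))   # ~0.7511
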
Example #5

class Embedding(nn.Module):
    def __init__(self,
                 n_tokens,
                 latent_dim,
                 padding_idx=None,
                 init='truncnorm'):
        super(Embedding, self).__init__()
        self.n_tokens = n_tokens
        self.latent_dim = latent_dim
        if padding_idx is not None:
            if padding_idx > 0:
                assert padding_idx < self.n_tokens, \
                    'padding_idx must be within n_tokens'
            elif padding_idx < 0:
                assert padding_idx >= -self.n_tokens, \
                    'padding_idx must be within n_tokens'
        self.padding_idx = padding_idx
        self.init = init

        self.reset_parameters()

    def reset_parameters(self):
        with torch.no_grad():
            if self.init == 'truncnorm':
                t = 1. / (self.n_tokens**(1 / 2))
                weights = truncnorm.rvs(-t,
                                        t,
                                        size=[self.n_tokens, self.latent_dim])
                self.weights = Parameter(torch.tensor(weights).float())
            elif self.init == 'zeros':
                self.weights = Parameter(
                    torch.Tensor(self.n_tokens, self.latent_dim))
                self.weights.zero_()

        if self.padding_idx is not None:
            with torch.no_grad():
                self.weights[self.padding_idx].zero_()

    def forward(self, x):
        x = F.embedding(x, self.weights, padding_idx=self.padding_idx)

        return x
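
A usage sketch for Embedding; the vocabulary size, latent dimension, and padding index are assumptions.

# Hypothetical usage: index 0 is reserved for padding, so its row stays zero.
emb = Embedding(n_tokens=1000, latent_dim=64, padding_idx=0)
tokens = torch.tensor([[5, 42, 0], [7, 0, 0]])
vectors = emb(tokens)      # -> torch.Size([2, 3, 64])
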
Example #6
File: BN.py  Project: wcf1065948474/COCO
class _GBN(nn.Module):
    __constants__ = [
        'track_running_stats', 'momentum', 'eps', 'weight', 'bias',
        'running_mean', 'running_var', 'num_batches_tracked', 'num_features',
        'affine'
    ]

    def __init__(self,
                 opt,
                 num_features,
                 eps=1e-5,
                 momentum=0.01,
                 affine=True,
                 track_running_stats=True):
        super(_GBN, self).__init__()
        self.opt = opt
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(
                torch.Tensor(self.opt.micro_in_macro, 1, num_features, 1, 1))
            self.bias = Parameter(
                torch.Tensor(self.opt.micro_in_macro, 1, num_features, 1, 1))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.running_mean = Parameter(torch.Tensor(self.opt.micro_in_macro,
                                                       1, num_features, 1, 1),
                                          requires_grad=False)
            self.running_var = Parameter(torch.Tensor(self.opt.micro_in_macro,
                                                      1, num_features, 1, 1),
                                         requires_grad=False)
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
        self.reset_parameters()

    def reset_running_stats(self):
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)

    def reset_parameters(self):
        self.reset_running_stats()
        if self.affine:
            init.ones_(self.weight)
            init.zeros_(self.bias)

    def _check_input_dim(self, input):
        raise NotImplementedError

    def forward(self, input):
        self._check_input_dim(input)
        output = self.g_b_n(input, self.running_mean, self.running_var,
                            self.weight, self.bias)
        return output

    def extra_repr(self):
        return '{num_features}, eps={eps}, momentum={momentum}, affine={affine}, ' \
               'track_running_stats={track_running_stats}'.format(**self.__dict__)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        super(_GBN,
              self)._load_from_state_dict(state_dict, prefix, local_metadata,
                                          strict, missing_keys,
                                          unexpected_keys, error_msgs)

    def g_b_n(self, input, running_mean, running_var, weight, beta):
        N, C, H, W = input.size()
        G = self.opt.micro_in_macro
        input = input.view(G, N // G, C, H, W)
        mean = torch.mean(input, (1, 3, 4), keepdim=True)
        var = torch.var(input, (1, 3, 4), keepdim=True)

        if self.training or not self.track_running_stats:
            # use per-group batch statistics; update running stats only when tracked
            if self.track_running_stats:
                running_mean.data = running_mean.data * (
                    1 - self.momentum) + mean * self.momentum
                running_var.data = running_var.data * (
                    1 - self.momentum) + var * self.momentum
            X_hat = (input - mean) / torch.sqrt(var + self.eps)
        else:
            X_hat = (input - running_mean) / torch.sqrt(running_var + self.eps)
        if weight is not None:  # affine transform is optional
            X_hat = X_hat * weight + beta
        output = X_hat.view(N, C, H, W)
        return output
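
_GBN is abstract (its _check_input_dim raises NotImplementedError), so using it requires a small concrete subclass. A hedged sketch follows; the opt object with a micro_in_macro attribute is assumed from the original project.

# Hypothetical 4D subclass and usage; the batch size must be divisible by micro_in_macro.
class GBN2d(_GBN):
    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input, got {}D'.format(input.dim()))

class Opt:
    micro_in_macro = 4     # number of micro-batches grouped inside one macro-batch

gbn = GBN2d(Opt(), num_features=32)
x = torch.randn(16, 32, 8, 8)
y = gbn(x)                 # -> torch.Size([16, 32, 8, 8])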