Example #1
class RiemannianLayer(nn.Module):
    def __init__(self, in_features, out_features, manifold, over_param, weight_norm):
        super(RiemannianLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.manifold = manifold

        self._weight = Parameter(torch.Tensor(out_features, in_features))
        self.over_param = over_param
        self.weight_norm = weight_norm
        if self.over_param:
            self._bias = ManifoldParameter(torch.Tensor(out_features, in_features), manifold=manifold)
        else:
            self._bias = Parameter(torch.Tensor(out_features, 1))
        self.reset_parameters()

    @property
    def weight(self):
        return self.manifold.transp0(self.bias, self._weight) # weight \in T_0 => weight \in T_bias

    @property
    def bias(self):
        if self.over_param:
            return self._bias
        else:
            return self.manifold.expmap0(self._weight * self._bias) # reparameterisation of a point on the manifold

    def reset_parameters(self):
        init.kaiming_normal_(self._weight, a=math.sqrt(5))
        fan_in, _ = init._calculate_fan_in_and_fan_out(self._weight)
        bound = 4 / math.sqrt(fan_in)
        init.uniform_(self._bias, -bound, bound)
        if self.over_param:
            with torch.no_grad(): self._bias.set_(self.manifold.expmap0(self._bias))
class RiemannLayer(nn.Module):

    def __init__(self, in_features, out_features, manifold, over_param):
        super(RiemannLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.manifold = manifold

        self._weight = Parameter(torch.Tensor(out_features, in_features))
        self.over_param = over_param
        if self.over_param:
            self._bias = ManifoldParameter(torch.Tensor(out_features, in_features), manifold=manifold)
        else:
            self._bias = Parameter(torch.Tensor(out_features, 1))
        self.reset_parameters()

    @property
    def weight(self):
        # transport the Euclidean weight from the tangent space at the origin
        # to the tangent space at the bias point
        return self.manifold.transp0(self.bias, self._weight)

    @property
    def bias(self):
        if self.over_param:
            return self._bias
        # reparameterisation of a point on the manifold
        return self.manifold.expmap0(self._weight * self._bias)

    def reset_parameters(self):
        nn.init.kaiming_normal_(self._weight, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self._weight)
        bound = 4 / math.sqrt(fan_in)
        nn.init.uniform_(self._bias, -bound, bound)
        if self.over_param:
            with torch.no_grad():
                self._bias.set_(self.manifold.expmap0(self._bias))
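The two layers above assume a manifold object that exposes expmap0 and transp0, plus a geoopt-style ManifoldParameter and the usual torch.nn.Parameter, torch.nn.init, and math imports. A minimal usage sketch under those assumptions (the PoincareBall instance and the sizes are illustrative, not taken from the original source):

import torch
import geoopt

ball = geoopt.PoincareBall(c=1.0)
layer = RiemannianLayer(in_features=8, out_features=4, manifold=ball,
                        over_param=False, weight_norm=False)

# bias is a point on the ball obtained via expmap0; weight is the Euclidean
# parameter transported into the tangent space at that bias point
print(layer.bias.shape)    # torch.Size([4, 8])
print(layer.weight.shape)  # torch.Size([4, 8])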
Example #3
class WSConv2d(nn.Conv2d):
    def __init__(self,
                 in_planes,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 multiplier=1.0,
                 rep_dim=1,
                 repeat_weight=True,
                 use_coeff=False):
        if rep_dim == 0:
            # repeat along the channel dim (dimension 0 of the weight tensor)
            super(WSConv2d,
                  self).__init__(int(in_planes),
                                 int(np.ceil(out_channels / multiplier)),
                                 kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
            self.rep_time = int(
                np.ceil(1. * out_channels / self.weight.shape[0]))
        elif rep_dim == 1:
            # repeat along the filter dim (dimension 1 of the weight tensor)
            super(WSConv2d,
                  self).__init__(int(np.ceil(in_planes / multiplier)),
                                 int(out_channels),
                                 kernel_size,
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
            self.rep_time = int(np.ceil(1. * in_planes / self.weight.shape[1]))

        self.in_planes = in_planes
        self.out_channels_ori = out_channels
        self.groups = groups
        self.multiplier = multiplier
        self.rep_dim = rep_dim
        self.repeat_weight = repeat_weight
        self.use_coeff = use_coeff
        # print(self.weight.shape)
        # import pdb; pdb.set_trace()

        self.conv1_stride_lr_1 = nn.Conv2d(in_planes,
                                           in_planes,
                                           kernel_size=3,
                                           stride=2,
                                           padding=0,
                                           bias=False)
        self.conv1_stride_lr_2 = nn.Conv2d(in_planes,
                                           self.rep_time,
                                           kernel_size=1,
                                           stride=1,
                                           padding=0,
                                           bias=False)
        self.coefficient = Parameter(torch.Tensor(self.rep_time),
                                     requires_grad=False)
        self.reuse = False
        self.coeff_grad = None

    def generate_share_weight(self,
                              base_weight,
                              rep_num,
                              coeff,
                              nchannel,
                              dim=0):
        ''' sample weights from base weight'''
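        # Each of the rep_num copies blends base_weight with a one-step circular
        # shift of itself along `dim`, mixed by coeff[i]; the copies are then
        # concatenated along `dim` and truncated to nchannel.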
        # pdb.set_trace()
        if rep_num == 1:
            return base_weight
        new_weight = []
        for i in range(rep_num):
            if dim == 0:
                new_weight_temp = torch.cat(
                    [base_weight[1:, :, :, :], base_weight[0:1, :, :, :]],
                    dim=0) * (1 - coeff[i])
            else:
                new_weight_temp = torch.cat(
                    [base_weight[:, 1:, :, :], base_weight[:, 0:1, :, :]],
                    dim=1) * (1 - coeff[i])
            new_weight.append(base_weight * coeff[i] + new_weight_temp)
        out = torch.cat(new_weight, dim=dim)

        if dim == 0:
            return out[:nchannel, :, :, :]
        else:
            return out[:, :nchannel, :, :]

    def forward(self, x):
        """
            TF-style 'SAME' padding, as in the EfficientNet TensorFlow implementation
        """
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih,
            0)
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw,
            0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])

        if self.use_coeff:
            if self.training:
                # set reuse to True for coefficient sharing
                if not self.reuse:
                    lr_conv1 = self.conv1_stride_lr_1(x)
                    # pdb.set_trace()
                    lr_conv1 = self.conv1_stride_lr_2(lr_conv1)
                    lr_conv1 = F.adaptive_avg_pool2d(lr_conv1, (1, 1))[:, :, 0,
                                                                       0]
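                    # batch-average the pooled responses and L2-normalise them
                    # below: one mixing coefficient per repeated copy, stored
                    # detached in self.coefficient and with grad in
                    # self.coeff_grad for the training path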

                    self.coefficient.set_(
                        F.normalize(torch.mean(lr_conv1, 0),
                                    dim=0).clone().detach())
                    # pdb.set_trace()
                    self.coeff_grad = F.normalize(torch.mean(lr_conv1, 0),
                                                  dim=0)

                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(self.weight,
                                                       self.rep_time,
                                                       self.coeff_grad,
                                                       self.out_channels_ori,
                                                       dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        # out = F.conv2d(x, self.generate_share_feature(self.rep_time, F.normalize(torch.mean(lr_conv1, 0), dim = 0)))
                else:
                    if self.repeat_weight:

                        out = F.conv2d(
                            x,
                            self.generate_share_weight(self.weight,
                                                       self.rep_time,
                                                       self.coeff_grad,
                                                       x.shape[1],
                                                       dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        out = F.conv2d(out_tmp, self.weight)

            else:
                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(
                                self.weight,
                                self.rep_time,
                                self.coefficient.detach(),
                                self.out_channels_ori,
                                dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp, self.coefficient.detach(),
                            self.out_channels_ori)
                        # out = F.conv2d(x, self.generate_share_feature(self.rep_time, self.coefficient.detach()))
                else:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(
                                self.weight,
                                self.rep_time,
                                self.coefficient.detach(),
                                x.shape[1],
                                dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x, self.coefficient.detach(),
                            self.out_channels_ori)

                        out = F.conv2d(out_tmp, self.weight)
        else:
            # print("use_coeff == False")
            if self.rep_dim == 0:
                out = F.conv2d(
                    x,
                    self.weight.repeat([self.rep_time, 1, 1,
                                        1])[:self.out_channels_ori, :, :, :],
                    None, 1)
            else:
                out = F.conv2d(
                    x,
                    self.weight.repeat([1, self.rep_time, 1,
                                        1])[:, :x.shape[1], :, :], None, 1)
        return out
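A minimal sketch of the plain weight-repetition path of WSConv2d (use_coeff=False, rep_dim=1), assuming the usual imports of the snippet above (torch, numpy as np, torch.nn as nn, torch.nn.functional as F, math, Parameter); the sizes are illustrative:

import torch

conv = WSConv2d(in_planes=16, out_channels=32, kernel_size=3,
                multiplier=2.0, rep_dim=1, use_coeff=False)
x = torch.randn(1, 16, 28, 28)
y = conv(x)
# the stored base weight covers only 8 of the 16 input channels and is
# repeated along dim 1 to match the input
print(conv.weight.shape)  # torch.Size([32, 8, 3, 3])
print(y.shape)            # torch.Size([1, 32, 28, 28])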
Example #4
class NESConv2d(nn.Conv2d):
    def __init__(self,
                 in_planes,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 multiplier=1.0,
                 spatial_multiplier=1.,
                 rep_dim=1,
                 repeat_weight=True,
                 use_coeff=False):
        if rep_dim == 0:
            # repeat along the channel dim (dimension 0 of the weight tensor)
            super(NESConv2d,
                  self).__init__(int(in_planes),
                                 int(np.ceil(out_channels / multiplier)),
                                 int(kernel_size * spatial_multiplier),
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
            self.rep_time = int(
                np.ceil(1. * out_channels / self.weight.shape[0]))
        elif rep_dim == 1:
            # repeat along the filter dim (dimension 1 of the weight tensor)
            super(NESConv2d,
                  self).__init__(int(np.ceil(in_planes / multiplier)),
                                 int(out_channels),
                                 int(kernel_size * spatial_multiplier),
                                 stride=stride,
                                 padding=padding,
                                 dilation=dilation,
                                 groups=groups,
                                 bias=bias)
            self.rep_time = int(np.ceil(1. * in_planes / self.weight.shape[1]))

        self.in_planes = in_planes
        self.out_channels_ori = out_channels
        self.groups = groups
        self.multiplier = multiplier
        self.spatial_multiplier = spatial_multiplier
        # specify the range for the w and h direction
        self.kernel_size = kernel_size
        self.w_wange = kernel_size * (spatial_multiplier - 1)
        self.h_wange = kernel_size * (spatial_multiplier - 1)

        self.rep_dim = rep_dim
        self.repeat_weight = repeat_weight
        self.use_coeff = use_coeff
        # print(self.weight.shape)
        # import pdb; pdb.set_trace()
        if spatial_multiplier > 1:
            out_num = int(self.rep_time * 3)
            self.conv1_stride_lr_1 = nn.Conv2d(in_planes,
                                               in_planes,
                                               kernel_size=3,
                                               stride=2,
                                               padding=0,
                                               bias=False)
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.relu = nn.ReLU6(inplace=True)
            self.conv1_stride_lr_2 = nn.Conv2d(in_planes,
                                               out_num,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0,
                                               bias=False)
            self.bn2 = nn.BatchNorm2d(out_num)
            self.coefficient = Parameter(torch.Tensor(out_num),
                                         requires_grad=False)
        else:
            out_num = int(self.rep_time)
            self.conv1_stride_lr_1 = nn.Conv2d(in_planes,
                                               in_planes,
                                               kernel_size=3,
                                               stride=2,
                                               padding=0,
                                               bias=False)
            self.bn1 = nn.BatchNorm2d(in_planes)
            self.relu = nn.ReLU6(inplace=True)
            self.conv1_stride_lr_2 = nn.Conv2d(in_planes,
                                               out_num,
                                               kernel_size=1,
                                               stride=1,
                                               padding=0,
                                               bias=False)
            self.bn2 = nn.BatchNorm2d(out_num)
            self.coefficient = Parameter(torch.Tensor(out_num),
                                         requires_grad=False)
        self.reuse = False
        self.coeff_grad = None

    def generate_share_weight(self,
                              base_weight,
                              rep_num,
                              coeff,
                              nchannel,
                              dim=0):
        ''' sample weights from base weight'''
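        # coeff packs three values per copy: coeff[i*3] mixes base_weight with a
        # one-channel circular shift of itself, while coeff[i*3+1] / coeff[i*3+2]
        # select a spatial window from the tiled kernel (note the spatially
        # shifted variant computed first is overwritten by the channel blend
        # below before it is appended)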
        # pdb.set_trace()
        if rep_num == 1:
            return base_weight
        new_weight = []
        for i in range(rep_num):
            w_idx = coeff[i * 3 + 1] * self.w_wange
            start_idx_w = int(w_idx)
            end_idx_w = start_idx_w + self.kernel_size
            w_frac = w_idx - start_idx_w

            h_idx = coeff[i * 3 + 2] * self.h_wange
            start_idx_h = int(h_idx)
            end_idx_h = start_idx_h + self.kernel_size
            h_frac = h_idx - start_idx_h

            new_weight_temp = torch.cat(
                [base_weight[:, :, :, :], base_weight[:, :, :, :]],
                dim=2)[:, :, start_idx_w:end_idx_w, :] * (
                    1 - w_frac) + base_weight * w_frac
            new_weight_temp = torch.cat(
                [base_weight[:, :, :, :], base_weight[:, :, :, :]],
                dim=3)[:, :, :, start_idx_h:end_idx_h] * (
                    1 - h_frac) + base_weight * h_frac
            if dim == 0:
                new_weight_temp = torch.cat(
                    [base_weight[1:, :, :, :], base_weight[0:1, :, :, :]],
                    dim=0) * (1 - coeff[int(i * 3)])
            else:
                new_weight_temp = torch.cat(
                    [base_weight[:, 1:, :, :], base_weight[:, 0:1, :, :]],
                    dim=1) * (1 - coeff[i])
            new_weight.append(base_weight * coeff[int(i * 3)] +
                              new_weight_temp)
        out = torch.cat(new_weight, dim=dim)

        if dim == 0:
            return out[:nchannel, :, :, :]
        else:
            return out[:, :nchannel, :, :]

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max(
            (oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih,
            0)
        pad_w = max(
            (ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw,
            0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])

        if self.use_coeff:
            if self.training:
                # set reuse to True for coefficient sharing
                if not self.reuse:
                    lr_conv1 = self.relu(self.bn1(self.conv1_stride_lr_1(x)))
                    # pdb.set_trace()
                    lr_conv1 = self.bn2(self.conv1_stride_lr_2(lr_conv1))
                    lr_conv1 = F.adaptive_avg_pool2d(lr_conv1, (1, 1))[:, :, 0,
                                                                       0]

                    self.coefficient.set_(
                        F.normalize(torch.mean(lr_conv1, 0),
                                    dim=0).clone().detach())
                    # pdb.set_trace()
                    self.coeff_grad = F.normalize(torch.mean(lr_conv1, 0),
                                                  dim=0)

                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(self.weight,
                                                       self.rep_time,
                                                       self.coeff_grad,
                                                       self.out_channels_ori,
                                                       dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        # out = F.conv2d(x, self.generate_share_feature(self.rep_time, F.normalize(torch.mean(lr_conv1, 0), dim = 0)))
                else:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(self.weight,
                                                       self.rep_time,
                                                       self.coeff_grad,
                                                       x.shape[1],
                                                       dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x,
                            F.normalize(torch.mean(lr_conv1, 0), dim=0),
                            self.out_channels_ori)
                        out = F.conv2d(out_tmp, self.weight)

            else:
                if self.rep_dim == 0:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(
                                self.weight,
                                self.rep_time,
                                self.coefficient.detach(),
                                self.out_channels_ori,
                                dim=0))
                    else:
                        out_tmp = F.conv2d(x, self.weight)
                        out = self.generate_share_feature(
                            self.rep_time, out_tmp, self.coefficient.detach(),
                            self.out_channels_ori)
                        # out = F.conv2d(x, self.generate_share_feature(self.rep_time, self.coefficient.detach()))
                else:
                    if self.repeat_weight:
                        out = F.conv2d(
                            x,
                            self.generate_share_weight(
                                self.weight,
                                self.rep_time,
                                self.coefficient.detach(),
                                x.shape[1],
                                dim=1))
                    else:
                        out_tmp = self.feature_wrapper(
                            self.rep_time, x, self.coefficient.detach(),
                            self.out_channels_ori)

                        out = F.conv2d(out_tmp, self.weight)
        else:
            # print("use_coeff == False")
            if self.rep_dim == 0:
                out = F.conv2d(
                    x,
                    self.weight.repeat([self.rep_time, 1, 1,
                                        1])[:self.out_channels_ori, :, :, :],
                    None, 1)
            else:
                out = F.conv2d(
                    x,
                    self.weight.repeat([1, self.rep_time, 1,
                                        1])[:, :x.shape[1], :, :], None, 1)
        return out
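NESConv2d adds a spatially enlarged base kernel on top of the same channel-sharing scheme. A minimal construction sketch under the same import assumptions, again with illustrative sizes:

import torch

conv = NESConv2d(in_planes=16, out_channels=32, kernel_size=3,
                 multiplier=2.0, spatial_multiplier=2., rep_dim=1,
                 use_coeff=False)
x = torch.randn(1, 16, 28, 28)
y = conv(x)
# base kernel holds 8 of the 16 input channels and is enlarged from 3x3 to 6x6
print(conv.weight.shape)  # torch.Size([32, 8, 6, 6])
print(y.shape)            # torch.Size([1, 32, 28, 28])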