Example #1
class VDropCentralData(nn.Module):
    """
    Stores data for a set of variational dropout (VDrop) modules in large
    central tensors. The VDrop modules access the data using views. This makes
    it possible to operate on all of the data at once (rather than e.g. 53
    times with resnet50).

    Usage:
    1. Instantiate
    2. Pass into multiple constructed VDropLinear and VDropConv2d modules
    3. Call finalize

    Before calling forward on the model, call "compute_forward_data".
    After calling forward on the model, call "clear_forward_data".

    The parameters are stored in terms of z_mu and z_var rather than w_mu and
    w_var to support group variational dropout (e.g. to allow for pruning entire
    channels).
    """
    def __init__(self, z_logvar_init=-10):
        super().__init__()
        self.z_chunk_sizes = []
        self.z_logvar_init = z_logvar_init
        self.z_logvar_min = min(z_logvar_init, -10)
        self.z_logvar_max = 10.
        self.epsilon = 1e-8
        self.data_views = {}
        # VDrop modules registered via register(); named to avoid shadowing
        # nn.Module.modules().
        self.registered_modules = []

        # Populated during register(), deleted during finalize()
        self.all_z_mu = []
        self.all_z_logvar = []
        self.all_num_weights = []

        # Populated during finalize()
        self.z_mu = None
        self.z_logvar = None
        self.z_num_weights = None

        self.threshold = 3

    def extra_repr(self):
        s = f"z_logvar_init={self.z_logvar_init}"
        return s

    def __getitem__(self, key):
        return self.data_views[key]

    def register(self, module, z_mu, z_logvar, num_weights_per_z=1):
        self.all_z_mu.append(z_mu.flatten())
        self.all_z_logvar.append(z_logvar.flatten())
        self.all_num_weights.append(num_weights_per_z)

        self.registered_modules.append(module)
        data_index = len(self.z_chunk_sizes)
        self.z_chunk_sizes.append(z_mu.numel())

        return data_index

    def finalize(self):
        self.z_mu = Parameter(torch.cat(self.all_z_mu))
        self.z_logvar = Parameter(torch.cat(self.all_z_logvar))
        self.z_num_weights = torch.tensor(self.all_num_weights,
                                          dtype=torch.float).repeat_interleave(
                                              torch.tensor(self.z_chunk_sizes))
        del self.all_z_mu
        del self.all_z_logvar
        del self.all_num_weights

    def to(self, *args, **kwargs):
        ret = super().to(*args, **kwargs)
        self.z_num_weights = self.z_num_weights.to(*args, **kwargs)
        return ret

    def compute_forward_data(self):
        if self.training:
            self.data_views["z_mu"] = self.z_mu.split(self.z_chunk_sizes)
            self.data_views["z_var"] = self.z_logvar.exp().split(
                self.z_chunk_sizes)
        else:
            self.data_views["z_mu"] = (
                self.z_mu *
                (self.compute_z_logalpha() < self.threshold).float()).split(
                    self.z_chunk_sizes)

    def clear_forward_data(self):
        self.data_views.clear()

    def compute_z_logalpha(self):
        return self.z_logvar - (self.z_mu.square() + self.epsilon).log()

    def regularization(self):
        return (vdrop_regularization(self.compute_z_logalpha()) *
                self.z_num_weights).sum()

    def constrain_parameters(self):
        self.z_logvar.data.clamp_(min=self.z_logvar_min, max=self.z_logvar_max)
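A minimal usage sketch of the three steps from the docstring, plus the per-batch
calls around forward. The VDropLinear constructor signature shown here (taking
the shared central-data object as an argument) is an assumption for
illustration, and loader and reg_weight are placeholders.

import torch
import torch.nn as nn
import torch.nn.functional as F

central_data = VDropCentralData()
model = nn.Sequential(
    VDropLinear(784, 300, central_data),  # hypothetical constructor signature
    nn.ReLU(),
    VDropLinear(300, 10, central_data),   # hypothetical constructor signature
)
central_data.finalize()

# z_mu / z_logvar live on central_data, so its parameters are optimized too.
optimizer = torch.optim.Adam(
    list(model.parameters()) + list(central_data.parameters()))

model.train()
central_data.train()  # not a child of the Sequential, so set its mode explicitly
for x, y in loader:
    central_data.compute_forward_data()
    loss = (F.cross_entropy(model(x), y)
            + reg_weight * central_data.regularization())
    central_data.clear_forward_data()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    central_data.constrain_parameters()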
Example #2
class VDropLinear2(nn.Module):
    """
    A self-contained VDropLinear (doesn't use the VDropCentralData)
    """
    def __init__(self,
                 in_features,
                 out_features,
                 bias=True,
                 w_logvar_init=-10):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.w_logvar_min = min(w_logvar_init, -10)
        self.w_logvar_max = 10.
        self.pruned_logvar_sentinel = self.w_logvar_max - 0.00058
        self.epsilon = 1e-8

        self.w_mu = Parameter(torch.Tensor(self.out_features,
                                           self.in_features))
        self.w_logvar = Parameter(
            torch.Tensor(self.out_features, self.in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.bias = None

        self.w_logvar.data.fill_(w_logvar_init)
        # Standard nn.Linear initialization.
        init.kaiming_uniform_(self.w_mu, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.w_mu)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available() else
                                   torch.cuda.FloatTensor)

    def extra_repr(self):
        s = f"{self.in_features}, {self.out_features}, "
        if self.bias is None:
            s += ", bias=False"
        return s

    def get_w_mu(self):
        return self.w_mu

    def get_w_var(self):
        return self.w_logvar.exp()

    def forward(self, x):
        if self.training:
            return vdrop_linear_forward(x, self.get_w_mu, self.get_w_var,
                                        self.bias, self.tensor_constructor)
        else:
            return F.linear(x, self.get_w_mu(), self.bias)

    def compute_w_logalpha(self):
        return self.w_logvar - (self.w_mu.square() + self.epsilon).log()

    def regularization(self):
        return vdrop_regularization(self.compute_w_logalpha()).sum()

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=self.w_logvar_min, max=self.w_logvar_max)
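vdrop_linear_forward is called in forward above but not defined in these
examples. A minimal sketch of what such a helper could look like, assuming the
local reparameterization trick (sampling Gaussian noise on the pre-activations
rather than on the weights); only the signature is taken from the call site,
everything else is an assumption.

import torch.nn.functional as F

def vdrop_linear_forward(x, get_w_mu, get_w_var, bias, tensor_constructor):
    # Mean and variance of the pre-activations under the weight posterior.
    y_mu = F.linear(x, get_w_mu(), bias)
    y_var = F.linear(x.square(), get_w_var())
    # Sample output-side noise; the small constant keeps the sqrt stable.
    noise = tensor_constructor(y_mu.size()).normal_()
    return y_mu + noise * (y_var + 1e-8).sqrt()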
Example #3
class MaskedVDropConv2d(nn.Module):
    """
    A self-contained masked VDropConv2d (doesn't use the VDropCentralData)
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 mask=None,
                 w_logvar_init=-10):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = pair(kernel_size)
        self.stride = pair(stride)
        self.padding = pair(padding)
        self.dilation = pair(dilation)
        self.groups = groups

        self.w_logvar_min = min(w_logvar_init, -10)
        self.w_logvar_max = 10.
        self.pruned_logvar_sentinel = self.w_logvar_max - 0.00058
        self.epsilon = 1e-8

        self.w_mu = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        self.w_logvar = Parameter(
            torch.Tensor(out_channels, in_channels // groups,
                         *self.kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.bias = None

        self.w_logvar.data.fill_(w_logvar_init)

        self.register_buffer(
            "w_mask",
            torch.HalfTensor(out_channels, in_channels // groups,
                             *self.kernel_size))

        # Standard nn.Conv2d initialization.
        init.kaiming_uniform_(self.w_mu, a=math.sqrt(5))

        if mask is not None:
            self.w_mask[:] = mask
            self.w_mu.data *= self.w_mask
            self.w_logvar.data[self.w_mask ==
                               0.0] = self.pruned_logvar_sentinel
        else:
            self.w_mask.fill_(1.0)

        # Standard nn.Conv2d bias initialization.
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.w_mu)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)

        self.tensor_constructor = (torch.FloatTensor
                                   if not torch.cuda.is_available() else
                                   torch.cuda.FloatTensor)

    def extra_repr(self):
        s = (f"{self.in_channels}, {self.out_channels}, "
             f"kernel_size={self.kernel_size}, stride={self.stride}")
        if self.padding != (0, ) * len(self.padding):
            s += f", padding={self.padding}"
        if self.dilation != (1, ) * len(self.dilation):
            s += f", dilation={self.dilation}"
        if self.groups != 1:
            s += f", groups={self.groups}"
        if self.bias is None:
            s += ", bias=False"
        return s

    def get_w_mu(self):
        return self.w_mu * self.w_mask

    def get_w_var(self):
        return self.w_logvar.exp() * self.w_mask

    def forward(self, x):
        if self.training:
            return vdrop_conv_forward(x, self.get_w_mu, self.get_w_var,
                                      self.bias, self.stride, self.padding,
                                      self.dilation, self.groups,
                                      self.tensor_constructor)
        else:
            return F.conv2d(x, self.get_w_mu(), self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

    def compute_w_logalpha(self):
        return self.w_logvar - (self.w_mu.square() + self.epsilon).log()

    def regularization(self):
        return (vdrop_regularization(self.compute_w_logalpha()) *
                self.w_mask).sum()

    def constrain_parameters(self):
        self.w_logvar.data.clamp_(min=self.w_logvar_min, max=self.w_logvar_max)
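All three regularization() methods above rely on a vdrop_regularization helper
that is not shown. A minimal sketch, assuming it returns the per-weight
approximate KL divergence as a function of log(alpha), using the fit from
Molchanov et al. (2017), "Variational Dropout Sparsifies Deep Neural Networks".

import torch
import torch.nn.functional as F

def vdrop_regularization(logalpha):
    """Approximate KL(q || p) per weight as a function of log(alpha).

    Uses the sigmoid/softplus fit from Molchanov et al. (2017); the additive
    constant is dropped because it does not affect gradients.
    """
    k1, k2, k3 = 0.63576, 1.87320, 1.48695
    neg_kl = (k1 * torch.sigmoid(k2 + k3 * logalpha)
              - 0.5 * F.softplus(-logalpha))
    return -neg_kl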
Example #4
class Linear(nn.Module):

    __constants__ = ['in_features', 'out_features']
    in_features: int
    out_features: int
    weight: Tensor

    def __init__(self, in_features: int, out_features: int, bias: bool = True, activation="ReLU", hidden_dim=None, hidden_activation="ReLU") -> None:
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.hidden_dim = hidden_dim
        self.hidden_activation = hidden_activation
        if hidden_dim is None:
            self.dims = vector(in_features, out_features)
            self.weight = Parameter(torch.zeros(out_features, in_features))
            if bias:
                self.bias = Parameter(torch.zeros(out_features))
            else:
                self.register_parameter('bias', None)
            self.activation = get_activation_layer(activation)
        else:
            self.dims = vector(in_features, *vector(hidden_dim), out_features)
            self.weight = nn.ParameterList(self.dims.map_k(lambda in_dim, out_dim: Parameter(torch.zeros(out_dim, in_dim)), 2))
            if bias:
                self.bias = nn.ParameterList(self.dims.map_k(lambda in_dim, out_dim: Parameter(torch.zeros(out_dim)), 2))
            else:
                self.register_parameter('bias', None)
            self.activation = vector(get_activation_layer(hidden_activation) for _ in range(len(hidden_dim)))
            self.activation.append(get_activation_layer(activation))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.hidden_dim is None:
            if isinstance(self.activation, torch.nn.ReLU) or self.activation == torch.relu:
                init.kaiming_normal_(self.weight, a=0, mode='fan_in', nonlinearity='relu')
            else:
                init.xavier_normal_(self.weight)
        else:
            for a, w in zip(self.activation, self.weight):
                if isinstance(a, torch.nn.ReLU) or a == torch.relu:
                    init.kaiming_normal_(w, a=0, mode='fan_in', nonlinearity='relu')
                else:
                    init.xavier_normal_(w)

    def forward(self, input: Tensor) -> Tensor:
        if self.hidden_dim is None:
            if self.activation is None:
                return F.linear(input, self.weight, self.bias)
            else:
                return self.activation(F.linear(input, self.weight, self.bias))
        else:
            h = input
            if self.bias is None:
                for w, a in zip(self.weight, self.activation):
                    h = a(F.linear(h, w, None))
            else:
                for w, b, a in zip(self.weight, self.bias, self.activation):
                    h = a(F.linear(h, w, b))
            return h

    def extra_repr(self) -> str:
        if self.activation is None:
            return 'in_features={}, out_features={}, bias={}'.format(self.in_features, self.out_features, self.bias is not None)
        elif isinstance(self.activation, vector):
            ret = 'in_features={}, out_features={}, bias={}, activation={}\n'.format(self.in_features, self.out_features, self.bias is not None, self.activation.map(lambda x: touch(lambda: x.__name__, str(x))))
            ret += "{}".format(self.in_features)
            for d, a in zip(self.dims[1:], self.activation):
                ret += '->{}->{}'.format(d, touch(lambda: a.__name__, str(a)))
            return ret
        else:
            ret = 'in_features={}, out_features={}, bias={}, activation={}'.format(self.in_features, self.out_features, self.bias is not None, touch(lambda: self.activation.__name__, str(self.activation)))
            return ret

    def regulization_loss(self, p=2):
        if self.hidden_dim is None:
            if p == 2:
                return self.weight.square().sum()
            if p == 1:
                return self.weight.abs().sum()
            return (self.weight.abs() ** p).sum()
        else:
            reg = []
            for w in self.weight:
                reg.append((w.abs() ** p).sum())
            return sum(reg)
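A short usage sketch for this Linear variant. The vector, get_activation_layer,
and touch helpers it uses come from the surrounding library and are not defined
here, so the accepted values for activation and hidden_dim (a string name and a
list of hidden widths) are assumptions.

import torch

single = Linear(128, 10, activation="ReLU")
mlp = Linear(128, 10, hidden_dim=[64, 32],
             activation="ReLU", hidden_activation="ReLU")

x = torch.randn(4, 128)
print(single(x).shape)             # torch.Size([4, 10])
print(mlp(x).shape)                # torch.Size([4, 10])
print(mlp.regulization_loss(p=2))  # L2 penalty summed over all weight matrices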