Example 1
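All of the examples below assume the same surrounding imports plus two project-specific helpers that the snippets themselves do not define: a `norm(tensor, dim)` function used by the weight-normalized layers and a `Flow` base class used by the ActNorm flows. The block below is a minimal, assumed sketch of those dependencies; the project's real implementations may differ.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from torch.nn.modules.utils import _pair
from typing import Dict, Tuple, Union
from overrides import overrides


def norm(p, dim):
    # Assumed helper: L2 norm of `p` over every dimension except `dim`, keeping dims
    # so the result broadcasts against `p` (e.g. [out, in] -> [out, 1]).
    dims = [d for d in range(p.dim()) if d != dim]
    return p.pow(2).sum(dim=dims, keepdim=True).sqrt()


class Flow(nn.Module):
    # Minimal stand-in for the project's Flow base class (assumed interface only).
    def __init__(self, inverse):
        super(Flow, self).__init__()
        self.inverse = inverse

    def forward(self, *inputs, **kwargs):
        raise NotImplementedError

    def backward(self, *inputs, **kwargs):
        raise NotImplementedError

    def init(self, *inputs, **kwargs):
        raise NotImplementedError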
class NIN4d(nn.Module):
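    """
    Network-in-network layer: a weight-normalized linear map applied over the channel
    dimension of a tensor with four trailing spatial dimensions (hence "4d").
    """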
    def __init__(self, in_features, out_features, bias=True):
        super(NIN4d, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_v = Parameter(torch.Tensor(out_features, in_features))
        self.weight_g = Parameter(torch.Tensor(out_features, 1))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features, 1, 1, 1, 1))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_g.data.copy_(norm(self.weight_v, 0))
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def compute_weight(self):
        return self.weight_v * (self.weight_g / norm(self.weight_v, 0))

    def forward(self, input):
        weight = self.compute_weight()
        out = torch.einsum('bc...,oc->bo...', (input, weight))
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

    def init(self, x, init_scale=1.0):
        with torch.no_grad():
            out = self(x)
            batch, out_features = out.size()[:2]
            assert out_features == self.out_features
            # [batch, out_features, h * w] -> [batch, h * w, out_features]
            out = out.view(batch, out_features, -1).transpose(1, 2)
            # [batch*height*width, out_features]
            out = out.contiguous().view(-1, out_features)
            # [out_features]
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)

            self.weight_g.mul_(inv_stdv.unsqueeze(1))
            if self.bias is not None:
                mean = mean.view(out_features, 1, 1, 1, 1)
                inv_stdv = inv_stdv.view(out_features, 1, 1, 1, 1)
                self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)
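A minimal usage sketch (hypothetical shapes, assuming the imports above): NIN4d mixes the channel dimension of a 6D tensor, and `init` performs the data-dependent rescaling of `weight_g` and `bias` before returning a normal forward pass.

nin = NIN4d(in_features=32, out_features=64)
x = torch.randn(8, 32, 4, 4, 4, 4)   # [batch, in_features, d1, d2, d3, d4]
y = nin.init(x, init_scale=1.0)      # data-dependent init, then a forward pass
print(y.shape)                       # torch.Size([8, 64, 4, 4, 4, 4])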
Example 2
class ActNorm2dFlow(Flow):
    def __init__(self, in_channels, inverse=False):
        super(ActNorm2dFlow, self).__init__(inverse)
        self.in_channels = in_channels
        self.log_scale = Parameter(torch.Tensor(in_channels, 1, 1))
        self.bias = Parameter(torch.Tensor(in_channels, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, in_channels, H, W]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, in_channels, H, W], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        batch, channels, H, W = input.size()
        out = input * self.log_scale.exp() + self.bias
        logdet = self.log_scale.sum(dim=0).squeeze(1).mul(H * W)
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, in_channels, H, W]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, in_channels, H, W], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        batch, channels, H, W = input.size()
        out = input - self.bias
        out = out.div(self.log_scale.exp() + 1e-8)
        logdet = self.log_scale.sum(dim=0).squeeze(1).mul(H * -W)
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            # [batch, n_channels, H, W]
            out, _ = self.forward(data)
            out = out.transpose(0, 1).contiguous().view(self.in_channels, -1)
            # [n_channels, 1, 1]
            mean = out.mean(dim=1).view(self.in_channels, 1, 1)
            std = out.std(dim=1).view(self.in_channels, 1, 1)
            inv_stdv = init_scale / (std + 1e-6)

            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_channels={}'.format(self.inverse, self.in_channels)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNorm2dFlow":
        return ActNorm2dFlow(**params)
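A hypothetical round-trip check (assuming the imports and stand-ins above): `backward` undoes `forward`, and its log-determinant has the opposite sign.

actnorm = ActNorm2dFlow(in_channels=3)
x = torch.randn(4, 3, 8, 8)
y, logdet_fwd = actnorm.forward(x)
x_rec, logdet_bwd = actnorm.backward(y)
print(torch.allclose(x, x_rec, atol=1e-5))      # True, up to the 1e-8 stabiliser
print(torch.allclose(logdet_fwd, -logdet_bwd))  # True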
Example 3
class ActNormFlow(Flow):
    def __init__(self, in_features, inverse=False):
        super(ActNormFlow, self).__init__(inverse)
        self.in_features = in_features
        self.log_scale = Parameter(torch.Tensor(in_features))
        self.bias = Parameter(torch.Tensor(in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_channels]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        out = input * self.log_scale.exp() + self.bias
        logdet = self.log_scale.sum(dim=0, keepdim=True)
        if input.dim() > 2:
            num = np.prod(input.size()[1:-1])
            logdet = logdet * num.astype(float)
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_channels]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        out = input - self.bias
        out = out.div(self.log_scale.exp() + 1e-8)
        logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
        if input.dim() > 2:
            num = np.prod(input.size()[1:-1])
            logdet = logdet * num.astype(float)
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            data: Tensor
                input tensor [batch, N1, N2, ..., in_channels]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        with torch.no_grad():
            out, _ = self.forward(data)
            mean = out.view(-1, self.in_features).mean(dim=0)
            std = out.view(-1, self.in_features).std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)

            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_features={}'.format(self.inverse, self.in_features)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNormFlow":
        return ActNormFlow(**params)
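A hypothetical sketch of the `from_params` constructor and data-dependent initialization (sizes are illustrative): after `init`, each output feature has roughly zero mean and a standard deviation close to `init_scale`.

flow = ActNormFlow.from_params({'in_features': 16})
data = torch.randn(128, 10, 16)               # [batch, N1, in_features]
out, logdet = flow.init(data, init_scale=1.0)
print(out.view(-1, 16).mean(dim=0))           # ~0 per feature
print(out.view(-1, 16).std(dim=0))            # ~1 per feature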
Example 4
class MaskedConv2d(nn.Module):
    """
    Conv2d with mask and weight normalization.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 mask_type='A',
                 order='A',
                 masked_channels=None,
                 stride=1,
                 dilation=1,
                 groups=1):
        super(MaskedConv2d, self).__init__()
        assert mask_type in {'A', 'B'}
        assert order in {'A', 'B'}
        self.mask_type = mask_type
        self.order = order
        kernel_size = _pair(kernel_size)
        for k in kernel_size:
            assert k % 2 == 1, 'kernel size must be odd: {}'.format(kernel_size)
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        stride = _pair(stride)
        dilation = _pair(dilation)

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        # masked all input channels by default
        masked_channels = in_channels if masked_channels is None else masked_channels
        self.masked_channels = masked_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

        self.weight_v = Parameter(
            torch.Tensor(out_channels, in_channels // groups, *kernel_size))
        self.weight_g = Parameter(torch.Tensor(out_channels, 1, 1, 1))
        self.bias = Parameter(torch.Tensor(out_channels))

        self.register_buffer('mask', torch.ones(self.weight_v.size()))
        _, _, kH, kW = self.weight_v.size()
        mask = np.ones([*self.mask.size()], dtype=np.float32)
        mask[:, :masked_channels, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
        mask[:, :masked_channels, kH // 2 + 1:] = 0

        # reverse order
        if order == 'B':
            reverse_mask = mask[:, :, ::-1, :]
            reverse_mask = reverse_mask[:, :, :, ::-1]
            mask = reverse_mask.copy()
        self.mask.copy_(torch.from_numpy(mask).float())
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0).data + 1e-8
        self.weight_g.data.copy_(_norm.log())
        nn.init.constant_(self.bias, 0)

    def init(self, x, init_scale=1.0):
        with torch.no_grad():
            # [batch, n_channels, H, W]
            out = self(x)
            n_channels = out.size(1)
            out = out.transpose(0, 1).contiguous().view(n_channels, -1)
            # [n_channels]
            mean = out.mean(dim=1)
            std = out.std(dim=1)
            inv_stdv = init_scale / (std + 1e-6)
            self.weight_g.add_(inv_stdv.log().view(n_channels, 1, 1, 1))
            self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0) + 1e-8
        weight = self.weight_v * (self.weight_g.exp() / _norm)
        return F.conv2d(input, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def extra_repr(self):
        s = (
            '{in_channels}({masked_channels}), {out_channels}, kernel_size={kernel_size}'
            ', stride={stride}')
        if self.padding != (0, ) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1, ) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ', type={mask_type}, order={order}'
        return s.format(**self.__dict__)
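A hypothetical check of the autoregressive mask (names and shapes are illustrative): with mask_type='A' the centre tap of the masked input channels is zeroed, so the current pixel is hidden, and the implicit padding keeps the spatial size unchanged.

conv = MaskedConv2d(in_channels=3, out_channels=8, kernel_size=3, mask_type='A', order='A')
kH, kW = conv.kernel_size
print(conv.mask[0, 0, kH // 2, kW // 2].item())  # 0.0: type 'A' hides the current pixel
x = torch.randn(2, 3, 16, 16)
y = conv.init(x, init_scale=1.0)                 # data-dependent init of weight_g / bias
print(y.shape)                                   # torch.Size([2, 8, 16, 16])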
Example 5
class ActNorm1dFlow(Flow):
    def __init__(self, in_features, inverse=False):
        super(ActNorm1dFlow, self).__init__(inverse)
        self.in_features = in_features
        self.log_scale = Parameter(torch.Tensor(in_features))
        self.bias = Parameter(torch.Tensor(in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor, mask: Union[torch.Tensor, None] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_channels]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ..., Nl]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        dim = input.dim()
        out = input * self.log_scale.exp() + self.bias
        if mask is not None:
            out = out * mask.unsqueeze(dim - 1)

        logdet = self.log_scale.sum(dim=0, keepdim=True)
        if dim > 2:
            num = np.prod(input.size()[1:-1]).astype(float) if mask is None else mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor, mask: Union[torch.Tensor, None] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_channels]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ..., Nl]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        dim = input.dim()
        out = (input - self.bias).div(self.log_scale.exp() + 1e-8)
        if mask is not None:
            out = out * mask.unsqueeze(dim - 1)

        logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
        if dim > 2:
            num = np.prod(input.size()[1:-1]).astype(float) if mask is None else mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet

    @overrides
    def init(self, data, mask: Union[torch.Tensor, None] = None, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        """

        Args:
            data: Tensor
                input tensor [batch, N1, N2, ..., in_channels]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ..., Nl]
            init_scale: float
                initial scale

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_channels], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`

        """
        with torch.no_grad():
            # [batch * N1 * ... * Nl, in_features]
            out, _ = self.forward(data, mask=mask)
            out = out.view(-1, self.in_features)
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)

            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data, mask=mask)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_features={}'.format(self.inverse, self.in_features)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNorm1dFlow":
        return ActNorm1dFlow(**params)
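A hypothetical masked-sequence sketch (lengths are illustrative): padded positions are zeroed in the output, and the per-sequence log-determinant counts only the unmasked positions.

flow = ActNorm1dFlow(in_features=8)
x = torch.randn(4, 10, 8)                        # [batch, length, in_features]
lengths = torch.tensor([10, 7, 5, 3])
mask = (torch.arange(10).unsqueeze(0) < lengths.unsqueeze(1)).float()  # [batch, length]
y, logdet = flow.forward(x, mask=mask)
print(logdet.shape)                              # torch.Size([4]), one entry per sequence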
Example 6
class MaskedLinear(nn.Module):
    """
    masked linear module with weight normalization
    """
    def __init__(self, in_features, out_features, mask_type, total_units, max_units=None, bias=True):
        """
        Args:
            in_features: number of units in the input.
            out_features: number of units in the output.
            mask_type: tuple (layer_type, order) specifying the kind of masked linear layer.
            total_units: the total number of units to assign.
            max_units: list giving, for each input unit, the maximum ordering index it may depend on.
            bias: whether to use a bias vector.
        """
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        self.weight_v = Parameter(torch.Tensor(out_features, in_features))
        self.weight_g = Parameter(torch.Tensor(out_features, 1))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        layer_type, order = mask_type
        self.layer_type = layer_type
        self.order = order
        assert layer_type in {'input-hidden', 'hidden-hidden', 'hidden-output', 'input-output'}
        assert order in {'A', 'B'}
        self.register_buffer('mask', self.weight_v.data.clone())

        # override the max_units for input layer
        if layer_type.startswith('input'):
            max_units = np.arange(in_features) + 1
        else:
            assert max_units is not None and len(max_units) == in_features

        if layer_type.endswith('output'):
            assert out_features > total_units
            self.max_units = np.arange(out_features)
            self.max_units[total_units:] = total_units
        else:
            units_per_units = float(total_units) / out_features
            self.max_units = np.zeros(out_features, dtype=np.int32)
            for i in range(out_features):
                self.max_units[i] = np.ceil((i + 1) * units_per_units)

        mask = np.zeros([out_features, in_features], dtype=np.float32)
        for i in range(out_features):
            for j in range(in_features):
                mask[i, j] = float(self.max_units[i] >= max_units[j])

        # reverse order
        if order == 'B':
            reverse_mask = mask[::-1, :]
            reverse_mask = reverse_mask[:, ::-1]
            mask = np.copy(reverse_mask)

        self.mask.copy_(torch.from_numpy(mask).float())
        self.reset_parameters()

        self._init = True

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0).data + 1e-8
        self.weight_g.data.copy_(_norm.log())
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)

    def initialize(self, x, init_scale=1.0):
        with torch.no_grad():
            # [batch, out_features]
            out = self(x)
            # [out_features]
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            std = std + std.le(0).float()
            inv_stdv = init_scale / (std + 1e-6)

            self.weight_g.add_(inv_stdv.log().unsqueeze(1))
            if self.bias is not None:
                self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0) + 1e-8
        weight = self.weight_v * (self.weight_g.exp() / _norm)
        return F.linear(input, weight, self.bias)

    @overrides
    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, type={}, order={}'.format(
            self.in_features, self.out_features, self.bias is not None,
            self.layer_type, self.order
        )
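A hypothetical two-layer MADE-style stack (sizes and names are illustrative, assuming the imports above): the masks chain through `max_units` so that output unit i only sees input units with a smaller ordering index.

n_units, hidden = 5, 20
l1 = MaskedLinear(n_units, hidden, mask_type=('input-hidden', 'A'), total_units=n_units)
l2 = MaskedLinear(hidden, n_units + 1, mask_type=('hidden-output', 'A'),
                  total_units=n_units, max_units=l1.max_units)
x = torch.randn(32, n_units)
out = l2(F.relu(l1(x)))
print(out.shape)                                 # torch.Size([32, 6])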