from typing import Dict, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from overrides import overrides
from torch.nn import Parameter
from torch.nn.modules.utils import _pair

# `Flow` is the base flow class defined elsewhere in this repo; the `norm`
# weight-normalization helper used throughout is sketched after NIN4d below.


class NIN4d(nn.Module):
    """Network-in-network (1x1 linear over the channel dimension) for inputs
    with four spatial dimensions, with weight normalization and
    data-dependent initialization."""

    def __init__(self, in_features, out_features, bias=True):
        super(NIN4d, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_v = Parameter(torch.Tensor(out_features, in_features))
        self.weight_g = Parameter(torch.Tensor(out_features, 1))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features, 1, 1, 1, 1))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_g.data.copy_(norm(self.weight_v, 0))
        if self.bias is not None:
            nn.init.constant_(self.bias, 0)

    def compute_weight(self):
        return self.weight_v * (self.weight_g / norm(self.weight_v, 0))

    def forward(self, input):
        weight = self.compute_weight()
        out = torch.einsum('bc...,oc->bo...', input, weight)
        if self.bias is not None:
            out = out + self.bias
        return out

    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)

    def init(self, x, init_scale=1.0):
        with torch.no_grad():
            out = self(x)
            batch, out_features = out.size()[:2]
            assert out_features == self.out_features
            # [batch, out_features, *spatial] -> [batch, *spatial, out_features]
            out = out.view(batch, out_features, -1).transpose(1, 2)
            # [batch * prod(spatial), out_features]
            out = out.contiguous().view(-1, out_features)
            # per-channel statistics: [out_features]
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)
            self.weight_g.mul_(inv_stdv.unsqueeze(1))
            if self.bias is not None:
                mean = mean.view(out_features, 1, 1, 1, 1)
                inv_stdv = inv_stdv.view(out_features, 1, 1, 1, 1)
                self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)
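# NOTE: a minimal sketch of the `norm` helper the modules in this file call.
# This is an assumption: it is taken to behave like the `_norm` used inside
# torch.nn.utils.weight_norm, i.e. the L2 norm over every dimension except
# `dim`, with singleton dims kept so the result broadcasts against the weight.
def norm(p: torch.Tensor, dim: int) -> torch.Tensor:
    # move `dim` to the front, flatten the rest, take row-wise L2 norms,
    # then reshape so every dimension except `dim` has size 1
    transposed = p.transpose(0, dim)
    norms = transposed.contiguous().view(transposed.size(0), -1).norm(dim=1)
    return norms.view(transposed.size(0), *((1,) * (p.dim() - 1))).transpose(0, dim)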
class ActNorm2dFlow(Flow):
    def __init__(self, in_channels, inverse=False):
        super(ActNorm2dFlow, self).__init__(inverse)
        self.in_channels = in_channels
        self.log_scale = Parameter(torch.Tensor(in_channels, 1, 1))
        self.bias = Parameter(torch.Tensor(in_channels, 1, 1))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, in_channels, H, W]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, in_channels, H, W], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`
        """
        batch, channels, H, W = input.size()
        out = input * self.log_scale.exp() + self.bias
        logdet = self.log_scale.sum(dim=0).squeeze(1).mul(H * W)
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, in_channels, H, W]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, in_channels, H, W], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`
        """
        batch, channels, H, W = input.size()
        out = input - self.bias
        out = out.div(self.log_scale.exp() + 1e-8)
        logdet = self.log_scale.sum(dim=0).squeeze(1).mul(-H * W)
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad():
            # [batch, n_channels, H, W]
            out, _ = self.forward(data)
            out = out.transpose(0, 1).contiguous().view(self.in_channels, -1)
            # per-channel statistics: [n_channels, 1, 1]
            mean = out.mean(dim=1).view(self.in_channels, 1, 1)
            std = out.std(dim=1).view(self.in_channels, 1, 1)
            inv_stdv = init_scale / (std + 1e-6)
            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_channels={}'.format(self.inverse, self.in_channels)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNorm2dFlow":
        return ActNorm2dFlow(**params)
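# Illustrative usage sketch (the function name `_demo_actnorm2d` is
# hypothetical, not part of the repo): run the data-dependent init on a first
# batch, then check that backward() inverts forward() and the logdets cancel.
def _demo_actnorm2d():
    flow = ActNorm2dFlow(in_channels=3)
    x = torch.randn(8, 3, 16, 16)
    # after init, flow outputs on this batch have ~zero mean and unit variance
    flow.init(x, init_scale=1.0)
    y, logdet = flow.forward(x)
    x_rec, logdet_inv = flow.backward(y)
    assert torch.allclose(x, x_rec, atol=1e-5)
    assert torch.allclose(logdet, -logdet_inv)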
class ActNormFlow(Flow):
    def __init__(self, in_features, inverse=False):
        super(ActNormFlow, self).__init__(inverse)
        self.in_features = in_features
        self.log_scale = Parameter(torch.Tensor(in_features))
        self.bias = Parameter(torch.Tensor(in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_features]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`
        """
        out = input * self.log_scale.exp() + self.bias
        logdet = self.log_scale.sum(dim=0, keepdim=True)
        if input.dim() > 2:
            num = float(np.prod(input.size()[1:-1]))
            logdet = logdet * num
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_features]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`
        """
        out = input - self.bias
        out = out.div(self.log_scale.exp() + 1e-8)
        logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
        if input.dim() > 2:
            num = float(np.prod(input.size()[1:-1]))
            logdet = logdet * num
        return out, logdet

    @overrides
    def init(self, data, init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            data: Tensor
                input tensor [batch, N1, N2, ..., in_features]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`
        """
        with torch.no_grad():
            out, _ = self.forward(data)
            mean = out.view(-1, self.in_features).mean(dim=0)
            std = out.view(-1, self.in_features).std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)
            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_features={}'.format(self.inverse, self.in_features)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNormFlow":
        return ActNormFlow(**params)
class MaskedConv2d(nn.Module):
    """
    Conv2d with mask and weight normalization.
    """

    def __init__(self, in_channels, out_channels, kernel_size, mask_type='A', order='A',
                 masked_channels=None, stride=1, dilation=1, groups=1):
        super(MaskedConv2d, self).__init__()
        assert mask_type in {'A', 'B'}
        assert order in {'A', 'B'}
        self.mask_type = mask_type
        self.order = order
        kernel_size = _pair(kernel_size)
        for k in kernel_size:
            assert k % 2 == 1, 'kernel size cannot include an even number: {}'.format(kernel_size)
        padding = (kernel_size[0] // 2, kernel_size[1] // 2)
        stride = _pair(stride)
        dilation = _pair(dilation)
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        # mask all input channels by default
        masked_channels = in_channels if masked_channels is None else masked_channels
        self.masked_channels = masked_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight_v = Parameter(torch.Tensor(out_channels, in_channels // groups, *kernel_size))
        self.weight_g = Parameter(torch.Tensor(out_channels, 1, 1, 1))
        self.bias = Parameter(torch.Tensor(out_channels))
        self.register_buffer('mask', torch.ones(self.weight_v.size()))
        _, _, kH, kW = self.weight_v.size()
        mask = np.ones([*self.mask.size()], dtype=np.float32)
        # zero the center row from the center pixel rightwards (for type 'B'
        # the center pixel itself is kept) ...
        mask[:, :masked_channels, kH // 2, kW // 2 + (mask_type == 'B'):] = 0
        # ... and all rows below the center, following raster-scan order
        mask[:, :masked_channels, kH // 2 + 1:] = 0
        # order 'B' reverses the scan: flip the mask vertically and horizontally
        if order == 'B':
            reverse_mask = mask[:, :, ::-1, :]
            reverse_mask = reverse_mask[:, :, :, ::-1]
            mask = reverse_mask.copy()
        self.mask.copy_(torch.from_numpy(mask).float())
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0).data + 1e-8
        self.weight_g.data.copy_(_norm.log())
        nn.init.constant_(self.bias, 0)

    def init(self, x, init_scale=1.0):
        with torch.no_grad():
            # [batch, n_channels, H, W]
            out = self(x)
            n_channels = out.size(1)
            out = out.transpose(0, 1).contiguous().view(n_channels, -1)
            # per-channel statistics: [n_channels]
            mean = out.mean(dim=1)
            std = out.std(dim=1)
            inv_stdv = init_scale / (std + 1e-6)
            self.weight_g.add_(inv_stdv.log().view(n_channels, 1, 1, 1))
            self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0) + 1e-8
        weight = self.weight_v * (self.weight_g.exp() / _norm)
        return F.conv2d(input, weight, self.bias,
                        self.stride, self.padding, self.dilation, self.groups)

    def extra_repr(self):
        s = ('{in_channels}({masked_channels}), {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ', type={mask_type}, order={order}'
        return s.format(**self.__dict__)
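# Illustrative check (hypothetical helper, not from the repo): probe the
# receptive field with autograd to confirm the autoregressive mask. With
# mask_type='A' and order='A', an output pixel must not see its own input
# pixel, nor anything after it in raster-scan order.
def _check_masked_conv_receptive_field():
    conv = MaskedConv2d(in_channels=1, out_channels=4, kernel_size=3, mask_type='A')
    x = torch.zeros(1, 1, 5, 5, requires_grad=True)
    out = conv(x)
    # the gradient of the center output w.r.t. the input marks visible pixels
    out[0, :, 2, 2].sum().backward()
    grad = x.grad[0, 0]
    assert grad[2, 2] == 0            # the current pixel itself is masked
    assert grad[2, 3:].eq(0).all()    # pixels to its right in the same row
    assert grad[3:, :].eq(0).all()    # all rows below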
class ActNorm1dFlow(Flow):
    def __init__(self, in_features, inverse=False):
        super(ActNorm1dFlow, self).__init__(inverse)
        self.in_features = in_features
        self.log_scale = Parameter(torch.Tensor(in_features))
        self.bias = Parameter(torch.Tensor(in_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.log_scale, mean=0, std=0.05)
        nn.init.constant_(self.bias, 0.)

    @overrides
    def forward(self, input: torch.Tensor,
                mask: Union[torch.Tensor, None] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_features]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ..., Nl]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`
        """
        dim = input.dim()
        out = input * self.log_scale.exp() + self.bias
        if mask is not None:
            out = out * mask.unsqueeze(dim - 1)
        logdet = self.log_scale.sum(dim=0, keepdim=True)
        if dim > 2:
            # with a mask, count only the unmasked positions per sample
            num = float(np.prod(input.size()[1:-1])) if mask is None \
                else mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet

    @overrides
    def backward(self, input: torch.Tensor,
                 mask: Union[torch.Tensor, None] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            input: Tensor
                input tensor [batch, N1, N2, ..., in_features]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ..., Nl]

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`
        """
        dim = input.dim()
        out = (input - self.bias).div(self.log_scale.exp() + 1e-8)
        if mask is not None:
            out = out * mask.unsqueeze(dim - 1)
        logdet = self.log_scale.sum(dim=0, keepdim=True) * -1.0
        if dim > 2:
            num = float(np.prod(input.size()[1:-1])) if mask is None \
                else mask.view(out.size(0), -1).sum(dim=1)
            logdet = logdet * num
        return out, logdet

    @overrides
    def init(self, data, mask: Union[torch.Tensor, None] = None,
             init_scale=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Args:
            data: Tensor
                input tensor [batch, N1, N2, ..., in_features]
            mask: Tensor or None
                mask tensor [batch, N1, N2, ..., Nl]
            init_scale: float
                initial scale

        Returns: out: Tensor, logdet: Tensor
            out: [batch, N1, N2, ..., in_features], the output of the flow
            logdet: [batch], the log determinant of :math:`\partial output / \partial input`
        """
        with torch.no_grad():
            out, _ = self.forward(data, mask=mask)
            # [batch * N1 * ... * Nl, in_features]
            out = out.view(-1, self.in_features)
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            inv_stdv = init_scale / (std + 1e-6)
            self.log_scale.add_(inv_stdv.log())
            self.bias.add_(-mean).mul_(inv_stdv)
            return self.forward(data, mask=mask)

    @overrides
    def extra_repr(self):
        return 'inverse={}, in_features={}'.format(self.inverse, self.in_features)

    @classmethod
    def from_params(cls, params: Dict) -> "ActNorm1dFlow":
        return ActNorm1dFlow(**params)
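# Illustrative usage sketch (hypothetical names): with variable-length
# sequences, the mask zeroes padded positions and the per-sample logdet counts
# only real timesteps, i.e. logdet[b] = log_scale.sum() * length[b].
def _demo_actnorm1d_masked():
    flow = ActNorm1dFlow(in_features=4)
    x = torch.randn(2, 5, 4)                     # [batch, timesteps, in_features]
    mask = torch.tensor([[1., 1., 1., 0., 0.],
                         [1., 1., 1., 1., 1.]])  # lengths 3 and 5
    flow.init(x, mask=mask, init_scale=1.0)
    out, logdet = flow.forward(x, mask=mask)
    expected = flow.log_scale.sum() * mask.sum(dim=1)
    assert torch.allclose(logdet, expected)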
class MaskedLinear(nn.Module):
    """
    Masked linear module with weight normalization.
    """

    def __init__(self, in_features, out_features, mask_type, total_units, max_units=None, bias=True):
        """
        Args:
            in_features: number of units in the inputs.
            out_features: number of units in the outputs.
            mask_type: type of the masked linear, a (layer_type, order) pair.
            total_units: the total number of units to assign.
            max_units: the list containing the maximum unit each input unit depends on.
            bias: using bias vector.
        """
        super(MaskedLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_v = Parameter(torch.Tensor(out_features, in_features))
        self.weight_g = Parameter(torch.Tensor(out_features, 1))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)

        layer_type, order = mask_type
        self.layer_type = layer_type
        self.order = order
        assert layer_type in {'input-hidden', 'hidden-hidden', 'hidden-output', 'input-output'}
        assert order in {'A', 'B'}
        self.register_buffer('mask', self.weight_v.data.clone())

        # the input layer overrides max_units: input unit j depends on units 1..j
        if layer_type.startswith('input'):
            max_units = np.arange(in_features) + 1
        else:
            assert max_units is not None and len(max_units) == in_features

        if layer_type.endswith('output'):
            assert out_features > total_units
            self.max_units = np.arange(out_features)
            self.max_units[total_units:] = total_units
        else:
            # spread the total_units evenly over the hidden units
            units_per_units = float(total_units) / out_features
            self.max_units = np.zeros(out_features, dtype=np.int32)
            for i in range(out_features):
                self.max_units[i] = np.ceil((i + 1) * units_per_units)

        mask = np.zeros([out_features, in_features], dtype=np.float32)
        for i in range(out_features):
            for j in range(in_features):
                mask[i, j] = float(self.max_units[i] >= max_units[j])

        # reverse the autoregressive order
        if order == 'B':
            reverse_mask = mask[::-1, :]
            reverse_mask = reverse_mask[:, ::-1]
            mask = np.copy(reverse_mask)
        self.mask.copy_(torch.from_numpy(mask).float())
        self.reset_parameters()
        self._init = True

    def reset_parameters(self):
        nn.init.normal_(self.weight_v, mean=0.0, std=0.05)
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0).data + 1e-8
        self.weight_g.data.copy_(_norm.log())
        if self.bias is not None:
            nn.init.constant_(self.bias, 0.)

    def initialize(self, x, init_scale=1.0):
        with torch.no_grad():
            # [batch, out_features]
            out = self(x)
            # [out_features]
            mean = out.mean(dim=0)
            std = out.std(dim=0)
            # guard against degenerate (zero) std from fully masked rows
            std = std + std.le(0).float()
            inv_stdv = init_scale / (std + 1e-6)
            self.weight_g.add_(inv_stdv.log().unsqueeze(1))
            if self.bias is not None:
                self.bias.add_(-mean).mul_(inv_stdv)
            return self(x)

    def forward(self, input):
        self.weight_v.data.mul_(self.mask)
        _norm = norm(self.weight_v, 0) + 1e-8
        weight = self.weight_v * (self.weight_g.exp() / _norm)
        return F.linear(input, weight, self.bias)

    @overrides
    def extra_repr(self):
        return 'in_features={}, out_features={}, bias={}, type={}, order={}'.format(
            self.in_features, self.out_features, self.bias is not None,
            self.layer_type, self.order)
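# Illustrative sketch (hypothetical helper, not from the repo): chaining
# MaskedLinear layers MADE-style. Each layer exposes its `max_units`, which is
# threaded into the next layer so the product of the masks stays
# autoregressive; the output layer requires out_features > total_units,
# hence the +1 here.
def _build_made_stack(in_features=4, hidden=16):
    l1 = MaskedLinear(in_features, hidden, ('input-hidden', 'A'),
                      total_units=in_features)
    l2 = MaskedLinear(hidden, hidden, ('hidden-hidden', 'A'),
                      total_units=in_features, max_units=l1.max_units)
    l3 = MaskedLinear(hidden, in_features + 1, ('hidden-output', 'A'),
                      total_units=in_features, max_units=l2.max_units)
    return nn.Sequential(l1, nn.ELU(), l2, nn.ELU(), l3)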