def test_layer_is_well_behaved(self):
    for gated_conv in [False, True]:
        with self.subTest(gated_conv=gated_conv):
            x = torch.randn(10, 3, 8, 8)
            module = DenseNet(in_channels=3,
                              out_channels=6,
                              num_blocks=1,
                              mid_channels=12,
                              depth=2,
                              growth=4,
                              dropout=0.0,
                              gated_conv=gated_conv,
                              zero_init=False)
            self.assert_layer_is_well_behaved(module, x)
def __init__(self, in_channels, num_context, num_blocks, mid_channels, depth, growth, dropout, gated_conv, coupling_network):
    assert in_channels % 2 == 0

    if coupling_network == "densenet":
        net = nn.Sequential(
            DenseNet(in_channels=in_channels // 2 + num_context,
                     out_channels=in_channels,
                     num_blocks=num_blocks,
                     mid_channels=mid_channels,
                     depth=depth,
                     growth=growth,
                     dropout=dropout,
                     gated_conv=gated_conv,
                     zero_init=True),
            ElementwiseParams2d(2, mode='sequential'))
    elif coupling_network == "conv":
        net = nn.Sequential(
            ConvNet(in_channels=in_channels // 2 + num_context,
                    out_channels=in_channels,
                    mid_channels=mid_channels,
                    num_layers=depth,
                    activation='relu'),
            ElementwiseParams2d(2, mode='sequential'))
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")

    super(ConditionalCoupling, self).__init__(coupling_net=net, scale_fn=scale_fn("tanh_exp"))
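# Not part of the repo: a minimal sanity-check sketch of the channel arithmetic
# assumed by the densenet/conv branches above. For the channel-wise split, the
# coupling net receives in_channels // 2 + num_context channels and emits
# in_channels channels, i.e. 2 parameters (shift, log-scale) per transformed channel.
import torch

in_channels, num_context = 8, 4
x = torch.randn(2, in_channels, 16, 16)
context = torch.randn(2, num_context, 16, 16)

x_id, x_tr = x.chunk(2, dim=1)              # identity half / transformed half
net_in = torch.cat([x_id, context], dim=1)  # assumed layout of the coupling net input
assert net_in.shape[1] == in_channels // 2 + num_context
assert 2 * x_tr.shape[1] == in_channels     # 2 params for each transformed channel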
def __init__(self, x_size, y_size, coupling_network, mid_channels, depth, num_blocks=None, dropout=None, gated_conv=None, checkerboard=False, flip=False):
    if checkerboard:
        in_channels = y_size[0] + x_size[0]
        out_channels = y_size[0] * 2
        split_dim = 3
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2
    else:
        in_channels = y_size[0] // 2 + x_size[0]
        out_channels = y_size[0]
        split_dim = 1
        assert x_size[1] == y_size[1] and x_size[2] == y_size[2]
        assert y_size[0] % 2 == 0, f"High-resolution input has shape {y_size} with a channel count that is not evenly divisible by 2"

    if coupling_network == "densenet":
        coupling_net = nn.Sequential(
            DenseNet(in_channels=in_channels,
                     out_channels=out_channels,
                     num_blocks=num_blocks,
                     mid_channels=mid_channels,
                     depth=depth,
                     growth=mid_channels,
                     dropout=dropout,
                     gated_conv=gated_conv,
                     zero_init=True),
            ElementwiseParams2d(2, mode='sequential'))
    elif coupling_network == "conv":
        coupling_net = nn.Sequential(
            ConvNet(in_channels=in_channels,
                    out_channels=out_channels,
                    mid_channels=mid_channels,
                    num_layers=depth,
                    weight_norm=True,
                    activation='relu'),
            ElementwiseParams2d(2, mode='sequential'))
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")

    super(SRCoupling, self).__init__(coupling_net=coupling_net,
                                     scale_fn=scale_fn("tanh_exp"),
                                     split_dim=split_dim,
                                     flip=flip)
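# Not part of the repo: a sketch of the shape constraint asserted in the
# checkerboard branch above. The conditioning tensor x is assumed to match the
# height of the high-resolution y and half its width, since the split runs along
# split_dim=3 (width) and x is concatenated channel-wise with the identity slice
# (which is why in_channels = y_size[0] + x_size[0] there).
import torch

x = torch.randn(2, 3, 16, 8)     # x_size = (3, 16, 8)
y = torch.randn(2, 4, 16, 16)    # y_size = (4, 16, 16)
y_id, y_tr = y.chunk(2, dim=3)   # width-wise split
net_in = torch.cat([y_id, x], dim=1)
assert net_in.shape == (2, 4 + 3, 16, 8)   # channels = y_size[0] + x_size[0]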
def net(channels):
    return nn.Sequential(DenseNet(in_channels=channels // 2,
                                  out_channels=channels,
                                  num_blocks=1,
                                  mid_channels=64,
                                  depth=8,
                                  growth=16,
                                  dropout=0.0,
                                  gated_conv=True,
                                  zero_init=True),
                         ElementwiseParams2d(2))
def __init__(self, in_channels, num_context, num_blocks, mid_channels, depth, dropout, gated_conv, coupling_network, checkerboard=False, flip=False):
    if checkerboard:
        num_in = in_channels + num_context
        num_out = in_channels * 2
        split_dim = 3
    else:
        num_in = in_channels // 2 + num_context
        num_out = in_channels
        split_dim = 1
    assert in_channels % 2 == 0 or split_dim != 1, f"in_channels = {in_channels} not evenly divisible by 2"

    if coupling_network == "densenet":
        net = nn.Sequential(
            DenseNet(in_channels=num_in,
                     out_channels=num_out,
                     num_blocks=num_blocks,
                     mid_channels=mid_channels,
                     depth=depth,
                     growth=mid_channels,
                     dropout=dropout,
                     gated_conv=gated_conv,
                     zero_init=True),
            ElementwiseParams2d(2, mode='sequential'))
    elif coupling_network == "conv":
        net = nn.Sequential(
            ConvNet(in_channels=num_in,
                    out_channels=num_out,
                    mid_channels=mid_channels,
                    num_layers=depth,
                    activation='relu'),
            ElementwiseParams2d(2, mode='sequential'))
    else:
        raise ValueError(f"Unknown coupling network {coupling_network}")

    super(ConditionalCoupling, self).__init__(coupling_net=net,
                                              scale_fn=scale_fn("tanh_exp"),
                                              split_dim=split_dim,
                                              flip=flip)
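# Not part of the repo: a sketch of why num_out differs between the branches above.
# With ElementwiseParams2d(2, ...) the net must emit 2 parameters per transformed
# channel: a width-wise (checkerboard) split transforms all in_channels channels,
# so num_out = 2 * in_channels; a channel-wise split transforms only
# in_channels // 2 of them, so num_out = in_channels.
in_channels = 8
for checkerboard in (True, False):
    transformed_channels = in_channels if checkerboard else in_channels // 2
    num_out = in_channels * 2 if checkerboard else in_channels
    assert num_out == 2 * transformed_channels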
def __init__(self, in_channels, num_blocks, mid_channels, depth, growth, dropout, gated_conv):
    assert in_channels % 2 == 0

    net = nn.Sequential(DenseNet(in_channels=in_channels // 2,
                                 out_channels=in_channels,
                                 num_blocks=num_blocks,
                                 mid_channels=mid_channels,
                                 depth=depth,
                                 growth=growth,
                                 dropout=dropout,
                                 gated_conv=gated_conv,
                                 zero_init=True),
                        ElementwiseParams2d(2, mode='sequential'))
    super(Coupling, self).__init__(coupling_net=net)
def __init__(self, num_bits, in_channels, out_channels, mid_channels, num_blocks, depth, dropout=0.0):
    super(ContextInit, self).__init__()
    self.dequant = UniformDequantization(num_bits=num_bits)
    self.shift = ScalarAffineBijection(shift=-0.5)

    self.encode = None
    if mid_channels > 0 and num_blocks > 0 and depth > 0:
        self.encode = DenseNet(in_channels=in_channels,
                               out_channels=out_channels,
                               num_blocks=num_blocks,
                               mid_channels=mid_channels,
                               depth=depth,
                               growth=mid_channels,
                               dropout=dropout,
                               gated_conv=False,
                               zero_init=False)
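# Not part of the repo: a plain-torch sketch of what the dequant + shift
# pre-processing above amounts to before the optional DenseNet encoder runs.
# Assumption: UniformDequantization maps integer pixels in [0, 2**num_bits) to
# [0, 1) by adding uniform noise, and ScalarAffineBijection(shift=-0.5)
# recenters the result to [-0.5, 0.5).
import torch

num_bits = 8
pixels = torch.randint(0, 2 ** num_bits, (2, 3, 8, 8)).float()
dequantized = (pixels + torch.rand_like(pixels)) / 2 ** num_bits
shifted = dequantized - 0.5
assert shifted.min() >= -0.5 and shifted.max() < 0.5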
def test_zero_init(self):
    x = torch.randn(10, 3, 8, 8)
    module = DenseNet(in_channels=3,
                      out_channels=6,
                      num_blocks=1,
                      mid_channels=12,
                      depth=2,
                      growth=4,
                      dropout=0.0,
                      gated_conv=False,
                      zero_init=True)
    y = module(x)
    self.assertEqual(y, torch.zeros(10, 6, 8, 8))
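# Not part of the repo: a minimal illustration of the property test_zero_init
# relies on. zero_init is assumed to zero the final convolution's weight and
# bias, so the untrained network maps any input to an all-zero tensor.
import torch
import torch.nn as nn

final = nn.Conv2d(12, 6, kernel_size=1)
nn.init.zeros_(final.weight)
nn.init.zeros_(final.bias)
x = torch.randn(10, 12, 8, 8)
assert torch.equal(final(x), torch.zeros(10, 6, 8, 8))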