Esempio n. 1
0
 def test_layer_is_well_behaved(self):
     # Run the generic layer checks on a small DenseNet, once with plain
     # convolutions and once with gated convolutions (each as its own subtest).
     for use_gating in (False, True):
         with self.subTest(gated_conv=use_gating):
             inputs = torch.randn(10, 3, 8, 8)
             layer = DenseNet(
                 in_channels=3,
                 out_channels=6,
                 num_blocks=1,
                 mid_channels=12,
                 depth=2,
                 growth=4,
                 dropout=0.0,
                 gated_conv=use_gating,
                 zero_init=False,
             )
             self.assert_layer_is_well_behaved(layer, inputs)
Esempio n. 2
0
    def __init__(self, in_channels, num_context, num_blocks, mid_channels,
                 depth, growth, dropout, gated_conv, coupling_network):
        """Build a coupling layer whose network also receives context channels.

        The coupling network takes ``in_channels // 2`` input channels plus
        ``num_context`` context channels and emits ``in_channels`` channels.
        That output is passed through ``ElementwiseParams2d(2, mode='sequential')``,
        presumably to form two elementwise parameters per coupled channel (see
        that class). The base class is configured with the ``"tanh_exp"`` scale
        function.

        Args:
            in_channels: number of input channels; must be even.
            num_context: number of context channels added to the network input.
            num_blocks: forwarded to DenseNet (``"densenet"`` only).
            mid_channels: forwarded to DenseNet / ConvNet as ``mid_channels``.
            depth: DenseNet ``depth``, or ConvNet ``num_layers`` for ``"conv"``.
            growth: forwarded to DenseNet (``"densenet"`` only).
            dropout: forwarded to DenseNet (``"densenet"`` only).
            gated_conv: forwarded to DenseNet (``"densenet"`` only).
            coupling_network: ``"densenet"`` or ``"conv"``.

        Raises:
            ValueError: if ``in_channels`` is odd, or ``coupling_network`` is not
                a supported name.
        """
        # Explicit check rather than `assert`: asserts vanish under `python -O`,
        # which would let a mis-shaped network be built silently.
        if in_channels % 2 != 0:
            raise ValueError(f"in_channels = {in_channels} not evenly divisible")

        if coupling_network == "densenet":
            net = nn.Sequential(
                DenseNet(in_channels=in_channels // 2 + num_context,
                         out_channels=in_channels,
                         num_blocks=num_blocks,
                         mid_channels=mid_channels,
                         depth=depth,
                         growth=growth,
                         dropout=dropout,
                         gated_conv=gated_conv,
                         zero_init=True),
                ElementwiseParams2d(2, mode='sequential'))
        elif coupling_network == "conv":
            net = nn.Sequential(
                ConvNet(in_channels=in_channels // 2 + num_context,
                        out_channels=in_channels,
                        mid_channels=mid_channels,
                        num_layers=depth,
                        activation='relu'),
                ElementwiseParams2d(2, mode='sequential'))
        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        super(ConditionalCoupling,
              self).__init__(coupling_net=net, scale_fn=scale_fn("tanh_exp"))
Esempio n. 3
0
    def __init__(self,
                 x_size,
                 y_size,
                 coupling_network,
                 mid_channels,
                 depth,
                 num_blocks=None,
                 dropout=None,
                 gated_conv=None,
                 checkerboard=False,
                 flip=False):
        """Build a coupling layer on ``y`` whose network is conditioned on ``x``.

        Sizes are indexed as ``(channels, size[1], size[2])``. They are
        presumably per-sample shapes without the batch dimension; confirm
        against callers.

        * ``checkerboard=False``: ``split_dim=1``. The network receives
          ``y_size[0] // 2`` channels of ``y`` plus ``x_size[0]`` context
          channels and emits ``y_size[0]`` channels. ``x`` must match ``y`` in
          ``size[1]`` and ``size[2]``, and ``y_size[0]`` must be even.
        * ``checkerboard=True``: ``split_dim=3``. The network receives all
          ``y_size[0]`` channels plus the context and emits ``2 * y_size[0]``
          channels. ``x`` must match ``y`` in ``size[1]`` and have
          ``size[2] == y_size[2] // 2``.

        The network output is passed through
        ``ElementwiseParams2d(2, mode='sequential')``. The base class receives
        the ``"tanh_exp"`` scale function, ``split_dim`` and ``flip``.

        Args:
            x_size: shape of the conditioning (context) tensor.
            y_size: shape of the tensor being transformed.
            coupling_network: ``"densenet"`` or ``"conv"``.
            mid_channels: forwarded to DenseNet (also as ``growth``) / ConvNet.
            depth: DenseNet ``depth``, or ConvNet ``num_layers`` for ``"conv"``.
            num_blocks: forwarded to DenseNet (``"densenet"`` only).
            dropout: forwarded to DenseNet (``"densenet"`` only).
            gated_conv: forwarded to DenseNet (``"densenet"`` only).
            checkerboard: select the ``split_dim=3`` layout described above.
            flip: forwarded to the base class.

        Raises:
            ValueError: if the shapes are incompatible, ``y_size[0]`` is odd in
                the channel-split layout, or ``coupling_network`` is unknown.
        """
        # Explicit checks rather than `assert`: asserts vanish under `python -O`.
        if checkerboard:
            in_channels = y_size[0] + x_size[0]
            out_channels = y_size[0] * 2
            split_dim = 3
            if not (x_size[1] == y_size[1] and x_size[2] == y_size[2] // 2):
                raise ValueError(
                    f"Context has shape {x_size}, incompatible with "
                    f"high-resolution shape {y_size} (checkerboard)")
        else:
            in_channels = y_size[0] // 2 + x_size[0]
            out_channels = y_size[0]
            split_dim = 1
            if not (x_size[1] == y_size[1] and x_size[2] == y_size[2]):
                raise ValueError(
                    f"Context has shape {x_size}, incompatible with "
                    f"high-resolution shape {y_size}")
            if y_size[0] % 2 != 0:
                raise ValueError(
                    f"High-resolution has shape {y_size} with channels not evenly divisible")

        if coupling_network == "densenet":
            coupling_net = nn.Sequential(
                DenseNet(in_channels=in_channels,
                         out_channels=out_channels,
                         num_blocks=num_blocks,
                         mid_channels=mid_channels,
                         depth=depth,
                         growth=mid_channels,
                         dropout=dropout,
                         gated_conv=gated_conv,
                         zero_init=True),
                ElementwiseParams2d(2, mode='sequential'))

        elif coupling_network == "conv":
            coupling_net = nn.Sequential(
                ConvNet(in_channels=in_channels,
                        out_channels=out_channels,
                        mid_channels=mid_channels,
                        num_layers=depth,
                        weight_norm=True,
                        activation='relu'),
                ElementwiseParams2d(2, mode='sequential'))

        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        super(SRCoupling, self).__init__(coupling_net=coupling_net,
                                         scale_fn=scale_fn("tanh_exp"),
                                         split_dim=split_dim,
                                         flip=flip)
Esempio n. 4
0
def net(channels, *, num_blocks=1, mid_channels=64, depth=8, growth=16,
        dropout=0.0, gated_conv=True):
    """Build a DenseNet coupling network followed by ``ElementwiseParams2d(2)``.

    The DenseNet reads ``channels // 2`` channels and emits ``channels``
    channels, which ``ElementwiseParams2d(2)`` then splits (presumably into two
    elementwise parameter maps; see that class). The keyword-only arguments
    default to the values this helper originally hard-coded, so
    ``net(channels)`` builds exactly the same network as before.

    Args:
        channels: number of channels of the tensor being coupled; must be even.
        num_blocks, mid_channels, depth, growth, dropout, gated_conv:
            forwarded unchanged to ``DenseNet``.

    Returns:
        ``nn.Sequential(DenseNet(...), ElementwiseParams2d(2))``.

    Raises:
        ValueError: if ``channels`` is odd. The network consumes exactly half
            of the channels, so an odd count cannot be split evenly.
    """
    if channels % 2 != 0:
        raise ValueError(f"channels = {channels} not evenly divisible")
    return nn.Sequential(
        DenseNet(in_channels=channels // 2,
                 out_channels=channels,
                 num_blocks=num_blocks,
                 mid_channels=mid_channels,
                 depth=depth,
                 growth=growth,
                 dropout=dropout,
                 gated_conv=gated_conv,
                 zero_init=True), ElementwiseParams2d(2))
Esempio n. 5
0
    def __init__(self,
                 in_channels,
                 num_context,
                 num_blocks,
                 mid_channels,
                 depth,
                 dropout,
                 gated_conv,
                 coupling_network,
                 checkerboard=False,
                 flip=False):
        """Build a coupling layer whose network also receives context channels.

        * ``checkerboard=False``: ``split_dim=1``. The network takes
          ``in_channels // 2 + num_context`` channels and emits
          ``in_channels``. ``in_channels`` must be even.
        * ``checkerboard=True``: ``split_dim=3``. The network takes
          ``in_channels + num_context`` channels and emits ``2 * in_channels``.

        The network output is passed through
        ``ElementwiseParams2d(2, mode='sequential')``. The base class receives
        the ``"tanh_exp"`` scale function, ``split_dim`` and ``flip``.

        Args:
            in_channels: number of channels of the tensor being coupled.
            num_context: number of context channels added to the network input.
            num_blocks: forwarded to DenseNet (``"densenet"`` only).
            mid_channels: forwarded to DenseNet (also as ``growth``) / ConvNet.
            depth: DenseNet ``depth``, or ConvNet ``num_layers`` for ``"conv"``.
            dropout: forwarded to DenseNet (``"densenet"`` only).
            gated_conv: forwarded to DenseNet (``"densenet"`` only).
            coupling_network: ``"densenet"`` or ``"conv"``.
            checkerboard: select the ``split_dim=3`` layout described above.
            flip: forwarded to the base class.

        Raises:
            ValueError: if ``in_channels`` is odd with ``split_dim == 1``, or
                ``coupling_network`` is not a supported name.
        """
        if checkerboard:
            num_in = in_channels + num_context
            num_out = in_channels * 2
            split_dim = 3
        else:
            num_in = in_channels // 2 + num_context
            num_out = in_channels
            split_dim = 1

        # Explicit check rather than `assert`: asserts vanish under `python -O`.
        # Only the channel split (split_dim == 1) needs an even channel count.
        if split_dim == 1 and in_channels % 2 != 0:
            raise ValueError(f"in_channels = {in_channels} not evenly divisible")

        if coupling_network == "densenet":
            net = nn.Sequential(
                DenseNet(in_channels=num_in,
                         out_channels=num_out,
                         num_blocks=num_blocks,
                         mid_channels=mid_channels,
                         depth=depth,
                         growth=mid_channels,
                         dropout=dropout,
                         gated_conv=gated_conv,
                         zero_init=True),
                ElementwiseParams2d(2, mode='sequential'))
        elif coupling_network == "conv":
            net = nn.Sequential(
                ConvNet(in_channels=num_in,
                        out_channels=num_out,
                        mid_channels=mid_channels,
                        num_layers=depth,
                        activation='relu'),
                ElementwiseParams2d(2, mode='sequential'))
        else:
            raise ValueError(f"Unknown coupling network {coupling_network}")

        super(ConditionalCoupling,
              self).__init__(coupling_net=net,
                             scale_fn=scale_fn("tanh_exp"),
                             split_dim=split_dim,
                             flip=flip)
Esempio n. 6
0
    def __init__(self, in_channels, num_blocks, mid_channels, depth, growth, dropout, gated_conv):
        """Build a coupling layer driven by a zero-initialised DenseNet.

        The DenseNet reads ``in_channels // 2`` channels and emits
        ``in_channels`` channels, which are passed through
        ``ElementwiseParams2d(2, mode='sequential')``. The base class is built
        with only ``coupling_net`` (its other arguments keep their defaults).

        Args:
            in_channels: number of channels of the tensor being coupled; must
                be even.
            num_blocks, mid_channels, depth, growth, dropout, gated_conv:
                forwarded unchanged to ``DenseNet``.

        Raises:
            ValueError: if ``in_channels`` is odd.
        """
        # Explicit check rather than `assert`: asserts vanish under `python -O`,
        # which would let a mis-shaped network be built silently.
        if in_channels % 2 != 0:
            raise ValueError(f"in_channels = {in_channels} not evenly divisible")

        net = nn.Sequential(DenseNet(in_channels=in_channels//2,
                                     out_channels=in_channels,
                                     num_blocks=num_blocks,
                                     mid_channels=mid_channels,
                                     depth=depth,
                                     growth=growth,
                                     dropout=dropout,
                                     gated_conv=gated_conv,
                                     zero_init=True),
                            ElementwiseParams2d(2, mode='sequential'))
        super(Coupling, self).__init__(coupling_net=net)
Esempio n. 7
0
    def __init__(self,
                 num_bits,
                 in_channels,
                 out_channels,
                 mid_channels,
                 num_blocks,
                 depth,
                 dropout=0.0):
        super(ContextInit, self).__init__()
        # Fixed submodules: a uniform dequantization layer parameterised by
        # num_bits, and a scalar affine bijection with shift -0.5. How they are
        # applied is defined outside this constructor.
        self.dequant = UniformDequantization(num_bits=num_bits)
        self.shift = ScalarAffineBijection(shift=-0.5)

        # The DenseNet encoder is optional. It is built only when mid_channels,
        # num_blocks and depth are all positive; otherwise `encode` is None.
        build_encoder = all(size > 0 for size in (mid_channels, num_blocks, depth))
        self.encode = DenseNet(
            in_channels=in_channels,
            out_channels=out_channels,
            num_blocks=num_blocks,
            mid_channels=mid_channels,
            depth=depth,
            growth=mid_channels,
            dropout=dropout,
            gated_conv=False,
            zero_init=False,
        ) if build_encoder else None
Esempio n. 8
0
 def test_zero_init(self):
     # A DenseNet built with zero_init=True is expected to output all zeros for
     # an arbitrary random input.
     x = torch.randn(10, 3, 8, 8)
     module = DenseNet(in_channels=3, out_channels=6, num_blocks=1, mid_channels=12,
                       depth=2, growth=4, dropout=0.0, gated_conv=False, zero_init=True)
     y = module(x)
     # Check the shape first for a readable failure, then exact values.
     # `torch.equal` yields a plain bool. Passing two multi-element tensors to
     # unittest's assertEqual makes it evaluate `not (a == b)` on a tensor,
     # which raises "Boolean value of Tensor ... is ambiguous".
     self.assertEqual(tuple(y.shape), (10, 6, 8, 8))
     self.assertTrue(torch.equal(y, torch.zeros_like(y)))