Exemplo n.º 1
0
    def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
        """Build the upsampling stages and hand them to the parent constructor.

        Each stage is ``conv`` followed by ``nn.PixelShuffle``, optionally
        followed by batch normalisation and an activation. A power-of-two
        ``scale`` is realised as ``log2(scale)`` consecutive x2 stages
        (``scale == 1`` therefore produces no stages); ``scale == 3`` is one
        x3 stage.

        Args:
            conv: Callable invoked as
                ``conv(in_channels, out_channels, kernel_size, bias)`` that
                returns the convolution layer of a stage.
            scale: Integer upscaling factor; a positive power of two, or 3.
            n_feats: Channel count used for the convolution input and for the
                normalisation / PReLU layers that follow the shuffle.
            bn: When truthy, append ``nn.BatchNorm2D(n_feats)`` to each stage.
            act: ``'relu'`` or ``'prelu'`` appends that activation to each
                stage; any other value (including the default ``False``)
                appends none.
            bias: Forwarded to ``conv`` as its fourth positional argument.

        Raises:
            NotImplementedError: If ``scale`` is neither a positive power of
                two nor 3.
        """
        if scale > 0 and (scale & (scale - 1)) == 0:  # Is scale = 2^n?
            # Exact integer log2. The float form int(math.log(scale, 2))
            # truncates, so any rounding below the true value would drop a
            # stage; it also fails with a math domain error for scale == 0.
            factor, n_stages = 2, int(scale).bit_length() - 1
        elif scale == 3:
            factor, n_stages = 3, 1
        else:
            raise NotImplementedError

        m = []
        for _ in range(n_stages):
            # The convolution widens channels by factor**2 so that the pixel
            # shuffle can be applied with upscale factor `factor`.
            m.append(conv(n_feats, factor * factor * n_feats, 3, bias))
            m.append(nn.PixelShuffle(factor))
            if bn:
                m.append(nn.BatchNorm2D(n_feats))

            if act == 'relu':
                m.append(nn.ReLU())
            elif act == 'prelu':
                m.append(nn.PReLU(n_feats))

        super(Upsampler, self).__init__(*m)
Exemplo n.º 2
0
    def __init__(self, n_channels, scale, group=1):
        """Create ``self.body``, a sequence of conv + pixel-shuffle stages.

        Args:
            n_channels: Input channel count of every convolution; the
                convolution outputs ``factor ** 2`` times as many channels.
            scale: 2, 4 or 8 gives ``log2(scale)`` x2 stages; 3 gives one x3
                stage. Any other value leaves ``self.body`` as an empty
                ``nn.Sequential``.
            group: Passed to each convolution as ``groups``.
        """
        super(_UpsampleBlock, self).__init__()

        # Decide the per-stage shuffle factor and how many stages to stack.
        if scale in (2, 4, 8):
            factor, repeats = 2, int(math.log(scale, 2))
        elif scale == 3:
            factor, repeats = 3, 1
        else:
            # Unsupported scale: no stages are created.
            factor, repeats = 1, 0

        stages = []
        for _ in range(repeats):
            # 3x3 convolution, stride 1, padding 1.
            stages.append(
                nn.Conv2D(n_channels,
                          factor * factor * n_channels,
                          3,
                          1,
                          1,
                          groups=group))
            stages.append(nn.PixelShuffle(factor))

        self.body = nn.Sequential(*stages)
Exemplo n.º 3
0
    def __init__(self,
                 in_nf=3,
                 out_nf=3,
                 scale_factor=4,
                 nf=64,
                 nframes=5,
                 groups=8,
                 front_RBs=5,
                 back_RBs=10,
                 center=None,
                 predeblur=False,
                 HR_in=False,
                 w_TSA=True,
                 TSA_only=False):
        """Store the configuration and create every sub-layer of the network.

        Args:
            in_nf: Input channel count of the first convolution (or of the
                pre-deblur module when ``predeblur`` is set).
            out_nf: Output channel count of ``conv_last``.
            scale_factor: Scale factor given to the bilinear ``nn.Upsample``.
            nf: Feature channel count shared by the internal layers.
            nframes: Frame count passed to ``TSAFusion``; also multiplies
                ``nf`` for the 1x1 convolution used when ``w_TSA`` is false.
            groups: Passed to ``PCDAlign``.
            front_RBs: Block count for the feature-extraction
                ``MakeMultiBlocks``.
            back_RBs: Block count for the reconstruction ``MakeMultiBlocks``.
            center: Stored as ``self.center``; defaults to ``nframes // 2``
                when ``None``.
            predeblur: Truthy selects the ``PredeblurResNetPyramid`` front end
                followed by a 1x1 convolution.
            HR_in: Truthy (without ``predeblur``) selects a three-convolution
                front end whose last two convolutions use stride 2; the flag
                is also forwarded to the pre-deblur module.
            w_TSA: Truthy uses ``TSAFusion`` as ``TSAModule``; otherwise a
                1x1 convolution is used.
            TSA_only: Accepted but not used by this constructor.
        """
        super(EDVRNet, self).__init__()

        def conv3x3(in_ch, out_ch, stride=1):
            # 3x3 convolution with padding 1; only channels and stride vary.
            return nn.Conv2D(in_channels=in_ch,
                             out_channels=out_ch,
                             kernel_size=3,
                             stride=stride,
                             padding=1)

        def conv1x1(in_ch, out_ch):
            # 1x1 convolution with stride 1.
            return nn.Conv2D(in_channels=in_ch,
                             out_channels=out_ch,
                             kernel_size=1,
                             stride=1)

        # Keep the configuration on the instance.
        self.in_nf = in_nf
        self.out_nf = out_nf
        self.scale_factor = scale_factor
        self.nf = nf
        self.nframes = nframes
        self.groups = groups
        self.front_RBs = front_RBs
        self.back_RBs = back_RBs
        if center is None:
            center = nframes // 2
        self.center = center
        # Normalise the switches to real booleans.
        self.predeblur = bool(predeblur)
        self.HR_in = bool(HR_in)
        self.w_TSA = bool(w_TSA)

        self.Leaky_relu = nn.LeakyReLU(negative_slope=0.1)

        # Front end: pre-deblur pyramid, high-resolution input path, or a
        # single convolution.
        if self.predeblur:
            self.pre_deblur = PredeblurResNetPyramid(in_nf=in_nf,
                                                     nf=nf,
                                                     HR_in=self.HR_in)
            self.cov_1 = conv1x1(nf, nf)
        elif self.HR_in:
            self.conv_first_1 = conv3x3(in_nf, nf)
            self.conv_first_2 = conv3x3(nf, nf, stride=2)
            self.conv_first_3 = conv3x3(nf, nf, stride=2)
        else:
            self.conv_first = conv3x3(in_nf, nf)

        # Feature extraction module: residual blocks plus the L2 / L3
        # convolutions (the first of each pair uses stride 2).
        self.feature_extractor = MakeMultiBlocks(ResidualBlockNoBN, front_RBs,
                                                 nf)
        self.fea_L2_conv1 = conv3x3(nf, nf, stride=2)
        self.fea_L2_conv2 = conv3x3(nf, nf)
        self.fea_L3_conv1 = conv3x3(nf, nf, stride=2)
        self.fea_L3_conv2 = conv3x3(nf, nf)

        # PCD alignment module.
        self.PCDModule = PCDAlign(nf=nf, groups=groups)

        # TSA fusion module, or a 1x1 convolution over nframes * nf channels.
        if self.w_TSA:
            self.TSAModule = TSAFusion(nf=nf,
                                       nframes=nframes,
                                       center=self.center)
        else:
            self.TSAModule = conv1x1(nframes * nf, nf)

        # Reconstruction module.
        self.reconstructor = MakeMultiBlocks(ResidualBlockNoBN, back_RBs, nf)
        self.upconv1 = conv3x3(nf, 4 * nf)
        self.pixel_shuffle = nn.PixelShuffle(2)
        self.upconv2 = conv3x3(nf, 4 * nf)
        self.HRconv = conv3x3(nf, nf)
        self.conv_last = conv3x3(nf, out_nf)
        self.upsample = nn.Upsample(scale_factor=scale_factor,
                                    mode="bilinear",
                                    align_corners=False,
                                    align_mode=0)
Exemplo n.º 4
0
    def func_test_layer_str(self):
        """Assert that ``str()`` of each constructed layer equals its expected text.

        Layers are built one at a time and compared immediately, so the first
        mismatch stops the method.
        """

        def check(layer, expected):
            # Compare the layer's string form with the expected literal.
            self.assertEqual(str(layer), expected)

        # Activation layers.
        check(nn.ELU(0.2), 'ELU(alpha=0.2)')
        check(nn.CELU(0.2), 'CELU(alpha=0.2)')
        check(nn.GELU(True), 'GELU(approximate=True)')
        check(nn.Hardshrink(), 'Hardshrink(threshold=0.5)')
        check(nn.Hardswish(name="Hardswish"), 'Hardswish(name=Hardswish)')
        check(nn.Tanh(name="Tanh"), 'Tanh(name=Tanh)')
        check(nn.Hardtanh(name="Hardtanh"),
              'Hardtanh(min=-1.0, max=1.0, name=Hardtanh)')
        check(
            nn.PReLU(1, 0.25, name="PReLU", data_format="NCHW"),
            'PReLU(num_parameters=1, data_format=NCHW, init=0.25, dtype=float32, name=PReLU)'
        )
        check(nn.ReLU(), 'ReLU()')
        check(nn.ReLU6(), 'ReLU6()')
        check(nn.SELU(),
              'SELU(scale=1.0507009873554805, alpha=1.6732632423543772)')
        check(nn.LeakyReLU(), 'LeakyReLU(negative_slope=0.01)')
        check(nn.Sigmoid(), 'Sigmoid()')
        check(nn.Hardsigmoid(), 'Hardsigmoid()')
        check(nn.Softplus(), 'Softplus(beta=1, threshold=20)')
        check(nn.Softshrink(), 'Softshrink(threshold=0.5)')
        check(nn.Softsign(), 'Softsign()')
        check(nn.Swish(), 'Swish()')
        check(nn.Tanhshrink(), 'Tanhshrink()')
        check(nn.ThresholdedReLU(), 'ThresholdedReLU(threshold=1.0)')
        check(nn.LogSigmoid(), 'LogSigmoid()')
        check(nn.Softmax(), 'Softmax(axis=-1)')
        check(nn.LogSoftmax(), 'LogSoftmax(axis=-1)')
        check(nn.Maxout(groups=2), 'Maxout(groups=2, axis=1)')

        # Linear, upsampling and bilinear layers.
        check(
            nn.Linear(2, 4, name='linear'),
            'Linear(in_features=2, out_features=4, dtype=float32, name=linear)'
        )
        check(
            nn.Upsample(size=[12, 12]),
            'Upsample(size=[12, 12], mode=nearest, align_corners=False, align_mode=0, data_format=NCHW)'
        )
        check(nn.UpsamplingNearest2D(size=[12, 12]),
              'UpsamplingNearest2D(size=[12, 12], data_format=NCHW)')
        check(nn.UpsamplingBilinear2D(size=[12, 12]),
              'UpsamplingBilinear2D(size=[12, 12], data_format=NCHW)')
        check(
            nn.Bilinear(in1_features=5, in2_features=4, out_features=1000),
            'Bilinear(in1_features=5, in2_features=4, out_features=1000, dtype=float32)'
        )

        # Dropout layers.
        check(nn.Dropout(p=0.5),
              'Dropout(p=0.5, axis=None, mode=upscale_in_train)')
        check(nn.Dropout2D(p=0.5), 'Dropout2D(p=0.5, data_format=NCHW)')
        check(nn.Dropout3D(p=0.5), 'Dropout3D(p=0.5, data_format=NCDHW)')
        check(nn.AlphaDropout(p=0.5), 'AlphaDropout(p=0.5)')

        # Padding layers.
        check(
            nn.Pad1D(padding=[1, 2], mode='constant'),
            'Pad1D(padding=[1, 2], mode=constant, value=0.0, data_format=NCL)')
        check(
            nn.Pad2D(padding=[1, 0, 1, 2], mode='constant'),
            'Pad2D(padding=[1, 0, 1, 2], mode=constant, value=0.0, data_format=NCHW)'
        )
        check(nn.ZeroPad2D(padding=[1, 0, 1, 2]),
              'ZeroPad2D(padding=[1, 0, 1, 2], data_format=NCHW)')
        check(
            nn.Pad3D(padding=[1, 0, 1, 2, 0, 0], mode='constant'),
            'Pad3D(padding=[1, 0, 1, 2, 0, 0], mode=constant, value=0.0, data_format=NCDHW)'
        )

        check(nn.CosineSimilarity(axis=0),
              'CosineSimilarity(axis=0, eps=1e-08)')
        check(nn.Embedding(10, 3, sparse=True),
              'Embedding(10, 3, sparse=True)')

        # Convolution layers and their transposes.
        check(nn.Conv1D(3, 2, 3),
              'Conv1D(3, 2, kernel_size=[3], data_format=NCL)')
        check(nn.Conv1DTranspose(2, 1, 2),
              'Conv1DTranspose(2, 1, kernel_size=[2], data_format=NCL)')
        check(nn.Conv2D(4, 6, (3, 3)),
              'Conv2D(4, 6, kernel_size=[3, 3], data_format=NCHW)')
        check(nn.Conv2DTranspose(4, 6, (3, 3)),
              'Conv2DTranspose(4, 6, kernel_size=[3, 3], data_format=NCHW)')
        check(nn.Conv3D(4, 6, (3, 3, 3)),
              'Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')
        check(
            nn.Conv3DTranspose(4, 6, (3, 3, 3)),
            'Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)')

        check(nn.PairwiseDistance(), 'PairwiseDistance(p=2.0)')

        # Normalisation layers.
        check(nn.InstanceNorm1D(2),
              'InstanceNorm1D(num_features=2, epsilon=1e-05)')
        check(nn.InstanceNorm2D(2),
              'InstanceNorm2D(num_features=2, epsilon=1e-05)')
        check(nn.InstanceNorm3D(2),
              'InstanceNorm3D(num_features=2, epsilon=1e-05)')
        check(nn.GroupNorm(num_channels=6, num_groups=6),
              'GroupNorm(num_groups=6, num_channels=6, epsilon=1e-05)')
        check(nn.LayerNorm([2, 2, 3]),
              'LayerNorm(normalized_shape=[2, 2, 3], epsilon=1e-05)')
        check(
            nn.BatchNorm1D(1),
            'BatchNorm1D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCL)'
        )
        check(nn.BatchNorm2D(1),
              'BatchNorm2D(num_features=1, momentum=0.9, epsilon=1e-05)')
        check(
            nn.BatchNorm3D(1),
            'BatchNorm3D(num_features=1, momentum=0.9, epsilon=1e-05, data_format=NCDHW)'
        )
        check(nn.SyncBatchNorm(2),
              'SyncBatchNorm(num_features=2, momentum=0.9, epsilon=1e-05)')
        check(nn.LocalResponseNorm(size=5),
              'LocalResponseNorm(size=5, alpha=0.0001, beta=0.75, k=1.0)')

        # Pooling layers.
        check(nn.AvgPool1D(kernel_size=2, stride=2, padding=0),
              'AvgPool1D(kernel_size=2, stride=2, padding=0)')
        check(nn.AvgPool2D(kernel_size=2, stride=2, padding=0),
              'AvgPool2D(kernel_size=2, stride=2, padding=0)')
        check(nn.AvgPool3D(kernel_size=2, stride=2, padding=0),
              'AvgPool3D(kernel_size=2, stride=2, padding=0)')
        check(nn.MaxPool1D(kernel_size=2, stride=2, padding=0),
              'MaxPool1D(kernel_size=2, stride=2, padding=0)')
        check(nn.MaxPool2D(kernel_size=2, stride=2, padding=0),
              'MaxPool2D(kernel_size=2, stride=2, padding=0)')
        check(nn.MaxPool3D(kernel_size=2, stride=2, padding=0),
              'MaxPool3D(kernel_size=2, stride=2, padding=0)')
        check(nn.AdaptiveAvgPool1D(output_size=16),
              'AdaptiveAvgPool1D(output_size=16)')
        check(nn.AdaptiveAvgPool2D(output_size=3),
              'AdaptiveAvgPool2D(output_size=3)')
        check(nn.AdaptiveAvgPool3D(output_size=3),
              'AdaptiveAvgPool3D(output_size=3)')
        check(nn.AdaptiveMaxPool1D(output_size=16, return_mask=True),
              'AdaptiveMaxPool1D(output_size=16, return_mask=True)')
        check(nn.AdaptiveMaxPool2D(output_size=3, return_mask=True),
              'AdaptiveMaxPool2D(output_size=3, return_mask=True)')
        check(nn.AdaptiveMaxPool3D(output_size=3, return_mask=True),
              'AdaptiveMaxPool3D(output_size=3, return_mask=True)')

        # Recurrent cells, pixel shuffle and multi-layer recurrent networks.
        check(nn.SimpleRNNCell(16, 32), 'SimpleRNNCell(16, 32)')
        check(nn.LSTMCell(16, 32), 'LSTMCell(16, 32)')
        check(nn.GRUCell(16, 32), 'GRUCell(16, 32)')
        check(nn.PixelShuffle(3), 'PixelShuffle(upscale_factor=3)')
        check(
            nn.SimpleRNN(16, 32, 2),
            'SimpleRNN(16, 32, num_layers=2\n  (0): RNN(\n    (cell): SimpleRNNCell(16, 32)\n  )\n  (1): RNN(\n    (cell): SimpleRNNCell(32, 32)\n  )\n)'
        )
        check(
            nn.LSTM(16, 32, 2),
            'LSTM(16, 32, num_layers=2\n  (0): RNN(\n    (cell): LSTMCell(16, 32)\n  )\n  (1): RNN(\n    (cell): LSTMCell(32, 32)\n  )\n)'
        )
        check(
            nn.GRU(16, 32, 2),
            'GRU(16, 32, num_layers=2\n  (0): RNN(\n    (cell): GRUCell(16, 32)\n  )\n  (1): RNN(\n    (cell): GRUCell(32, 32)\n  )\n)'
        )

        # A Sequential built from (name, layer) pairs prints the given names.
        module1 = nn.Sequential(
            ('conv1', nn.Conv2D(1, 20, 5)), ('relu1', nn.ReLU()),
            ('conv2', nn.Conv2D(20, 64, 5)), ('relu2', nn.ReLU()))
        check(
            module1,
            'Sequential(\n  '
            '(conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n  '
            '(relu1): ReLU()\n  '
            '(conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n  '
            '(relu2): ReLU()\n)')

        # A Sequential of positional layers prints indices; the nested
        # module1 is rendered with one extra level of indentation.
        module2 = nn.Sequential(
            nn.Conv3DTranspose(4, 6, (3, 3, 3)),
            nn.AvgPool3D(kernel_size=2, stride=2, padding=0),
            nn.Tanh(name="Tanh"), module1, nn.Conv3D(4, 6, (3, 3, 3)),
            nn.MaxPool3D(kernel_size=2, stride=2, padding=0), nn.GELU(True))
        check(
            module2,
            'Sequential(\n  '
            '(0): Conv3DTranspose(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n  '
            '(1): AvgPool3D(kernel_size=2, stride=2, padding=0)\n  '
            '(2): Tanh(name=Tanh)\n  '
            '(3): Sequential(\n    (conv1): Conv2D(1, 20, kernel_size=[5, 5], data_format=NCHW)\n    (relu1): ReLU()\n'
            '    (conv2): Conv2D(20, 64, kernel_size=[5, 5], data_format=NCHW)\n    (relu2): ReLU()\n  )\n  '
            '(4): Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)\n  '
            '(5): MaxPool3D(kernel_size=2, stride=2, padding=0)\n  '
            '(6): GELU(approximate=True)\n)')