# Example 1
    def __init__(self,
                 in_channels,
                 max_channels,
                 num_convs=5,
                 fc_in_channels=None,
                 fc_out_channels=1024,
                 kernel_size=5,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 out_act_cfg=dict(type='ReLU'),
                 with_input_norm=True,
                 with_out_convs=False,
                 with_spectral_norm=False,
                 **kwargs):
        """Build the discriminator: a stack of stride-2 convs, optional two
        extra stride-1 output convs, and an optional fully-connected head.

        Args:
            in_channels (int): Channels of the input feature map.
            max_channels (int): Cap on the per-layer channel count; layer i
                gets ``min(64 * 2**i, max_channels)`` channels.
            num_convs (int): Number of stride-2 downsampling convs.
            fc_in_channels (int | None): Input features of the FC head; the
                FC head is built only when this is not None.
            fc_out_channels (int): Output features of the FC head.
            kernel_size (int): Kernel size of every conv layer.
            conv_cfg (dict | None): Conv config.
                NOTE(review): accepted but never forwarded to ``ConvModule``
                below — confirm whether this is intentional.
            norm_cfg (dict | None): Norm config for the conv layers.
            act_cfg (dict | None): Activation config for intermediate convs.
            out_act_cfg (dict | None): Activation config for the final conv
                (when neither FC head nor out-convs are used) and FC head.
            with_input_norm (bool): If False, the first conv gets no norm.
            with_out_convs (bool): Whether to append two stride-1 out convs.
            with_spectral_norm (bool): Apply spectral norm to every layer.
            **kwargs: Forwarded to each ``ConvModule``.
        """
        super(MultiLayerDiscriminator, self).__init__()
        if fc_in_channels is not None:
            assert fc_in_channels > 0

        self.max_channels = max_channels
        self.with_fc = fc_in_channels is not None
        self.num_convs = num_convs
        self.with_out_act = out_act_cfg is not None
        self.with_out_convs = with_out_convs

        cur_channels = in_channels
        for idx in range(num_convs):
            next_channels = min(64 * 2**idx, max_channels)
            # Start from the defaults; the first/last layers may override.
            layer_norm_cfg = norm_cfg
            layer_act_cfg = act_cfg
            is_first = idx == 0
            is_last = idx == num_convs - 1
            if is_first and not with_input_norm:
                # Optionally skip normalization right at the input.
                layer_norm_cfg = None
            elif is_last and not (self.with_fc or self.with_out_convs):
                # This conv produces the final output: no norm, and use the
                # dedicated output activation instead of the hidden one.
                layer_norm_cfg = None
                layer_act_cfg = out_act_cfg
            self.add_module(
                f'conv{idx + 1}',
                ConvModule(cur_channels,
                           next_channels,
                           kernel_size=kernel_size,
                           stride=2,
                           padding=kernel_size // 2,
                           norm_cfg=layer_norm_cfg,
                           act_cfg=layer_act_cfg,
                           with_spectral_norm=with_spectral_norm,
                           **kwargs))
            cur_channels = next_channels

        if self.with_out_convs:
            # Two extra stride-1 convs; the second maps to a single channel
            # with no activation (raw score map).
            cur_channels = min(64 * 2**(num_convs - 1), max_channels)
            next_channels = min(64 * 2**num_convs, max_channels)
            self.add_module(
                f'conv{num_convs + 1}',
                ConvModule(cur_channels,
                           next_channels,
                           kernel_size,
                           stride=1,
                           padding=kernel_size // 2,
                           norm_cfg=norm_cfg,
                           act_cfg=act_cfg,
                           with_spectral_norm=with_spectral_norm,
                           **kwargs))
            self.add_module(
                f'conv{num_convs + 2}',
                ConvModule(next_channels,
                           1,
                           kernel_size,
                           stride=1,
                           padding=kernel_size // 2,
                           act_cfg=None,
                           with_spectral_norm=with_spectral_norm,
                           **kwargs))

        if self.with_fc:
            # Fully-connected head appended after the conv stack.
            self.fc = LinearModule(fc_in_channels,
                                   fc_out_channels,
                                   bias=True,
                                   act_cfg=out_act_cfg,
                                   with_spectral_norm=with_spectral_norm)
# Example 2
def test_linear_module():
    """Smoke-test LinearModule: default config, spectral norm, activation
    choices, the bias flag and a custom op order."""
    sample = torch.rand((3, 10))

    # Default config: bias on, ReLU activation, no spectral norm.
    module = LinearModule(10, 20)
    module.init_weights()
    assert module.with_bias
    assert not module.with_spectral_norm
    assert module.out_features == 20
    assert module.in_features == 10
    assert isinstance(module.activate, nn.ReLU)
    assert module(sample).shape == (3, 20)

    # Spectral norm wraps the weight as `weight_orig`; no activation.
    module = LinearModule(10, 20, act_cfg=None, with_spectral_norm=True)
    assert hasattr(module.linear, 'weight_orig')
    assert not module.with_activation
    assert module(sample).shape == (3, 20)

    # Spectral norm combined with a LeakyReLU activation.
    module = LinearModule(10,
                          20,
                          act_cfg=dict(type='LeakyReLU'),
                          with_spectral_norm=True)
    assert module(sample).shape == (3, 20)
    assert isinstance(module.activate, nn.LeakyReLU)

    # Bias disabled.
    module = LinearModule(10,
                          20,
                          bias=False,
                          act_cfg=None,
                          with_spectral_norm=True)
    assert module(sample).shape == (3, 20)
    assert not module.with_bias

    # Custom operation order is recorded as given.
    module = LinearModule(10,
                          20,
                          bias=False,
                          act_cfg=None,
                          with_spectral_norm=True,
                          order=('act', 'linear'))
    assert module.order == ('act', 'linear')