예제 #1
0
def test_verify_module_non_sequential():
    """verify_module rejects a bare nn.Module: only nn.Sequential can be partitioned."""
    expected_message = 'module must be nn.Sequential to be partitioned'
    with pytest.raises(TypeError, match=expected_message):
        verify_module(nn.Module())
예제 #2
0
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, give_pre_end=False, **ignorekwargs):
        """Build a convolutional decoder from a latent ``z`` to ``out_ch`` channels.

        All arguments are keyword-only.

        Args:
            ch: base channel count; level ``i`` uses ``ch * ch_mult[i]`` channels.
            out_ch: number of channels produced by ``conv_out``.
            ch_mult: per-level channel multipliers; its length is the number of levels.
            num_res_blocks: each up level holds ``num_res_blocks + 1`` ResnetBlocks.
            attn_resolutions: spatial sizes at which an AttnBlock follows each
                ResnetBlock of that level.
            dropout: forwarded to every ResnetBlock.
            resamp_with_conv: forwarded to every Upsample.
            in_channels: stored on ``self.in_channels``; not otherwise used here.
            resolution: output resolution; the latent side is
                ``resolution // 2**(len(ch_mult) - 1)``.
            z_channels: channel count of the input latent.
            give_pre_end: stored on ``self.give_pre_end`` (presumably read by
                forward -- confirm).
            **ignorekwargs: accepted and ignored.
        """
        super().__init__()
        self.ch = ch
        # No timestep embedding: every ResnetBlock below gets temb_channels=0.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,)+tuple(ch_mult)
        block_in = ch*ch_mult[self.num_resolutions-1]
        curr_res = resolution // 2**(self.num_resolutions-1)
        self.z_shape = (1,z_channels,curr_res,curr_res)
        print("Working with z of shape {} = {} dimensions.".format(
            self.z_shape, np.prod(self.z_shape)))

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels,
                                       block_in,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        # middle: ResnetBlock -> AttnBlock -> ResnetBlock at the lowest resolution
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling: levels are built from the lowest resolution upward
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            # Every level except level 0 (the highest resolution) ends with a 2x upsample.
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
예제 #3
0
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        constant_input=False,
        checkpoint=None,
        output_size=None,
    ):
        """Build a StyleGAN2-style generator, optionally re-shaped for another output size.

        Args:
            size: nominal square output resolution; ``log2(size)`` sets the layer count.
            style_dim: width of the latent/style vectors.
            n_mlp: number of EqualLinear layers after PixelNorm in ``self.style``.
            channel_multiplier: scales channel counts for resolutions >= 64.
            blur_kernel: forwarded to every StyledConv.
            lr_mlp: learning-rate multiplier for the mapping network.
            constant_input: use ConstantInput instead of LatentInput as the 4x4 input.
            checkpoint: optional path; its ``"g_ema"`` entry is loaded into this module.
            output_size: when not None and ``constant_input`` is set, the constant input
                is reshaped (1920: widened by mirrored padding, 512: center crop,
                anything else: kept), and the noise buffers are rebuilt to match it.
        """
        super().__init__()

        self.size = size
        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(EqualLinear(style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"))

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.n_latent = self.log_size * 2 - 2

        if constant_input:
            self.input = ConstantInput(self.channels[4])
        else:
            self.input = LatentInput(style_dim, self.channels[4])

        layerID = 1
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel, layerID=layerID
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Default noise: layer 0 is 4x4, then each pair of layers doubles the size.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", th.randn(*shape))

        # One (upsampling conv, conv, to_rgb) triple per resolution 8 .. size.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            layerID += 1
            self.convs.append(
                StyledConv(
                    in_channel, out_channel, 3, style_dim, upsample=True, blur_kernel=blur_kernel, layerID=layerID
                )
            )

            layerID += 1
            self.convs.append(
                StyledConv(out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel, layerID=layerID)
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.truncation_latent = None

        if checkpoint is not None:
            self.load_state_dict(th.load(checkpoint)["g_ema"])

        if output_size is not None and constant_input:
            const = self.input.input
            if size != 1024:
                # Draw a fresh constant at the size a 1024 model's input would have,
                # with the same standard deviation as the current one.
                means = th.zeros(size=(1, 512, int(4 * 1024 / size), int(4 * 1024 / size)))
                const = th.normal(mean=means, std=th.ones_like(means) * const.std())

            _, _, ch, cw = const.shape
            if output_size == 1920:
                # Widen: mirrored left half + original + right half along the width axis.
                layer0 = th.cat(
                    [
                        const[:, :, :, : cw // 2 + 1][:, :, :, list(range(cw // 2, 0, -1))],
                        const,
                        const[:, :, :, cw // 2 :],
                    ],
                    axis=3,
                )
            elif output_size == 512:
                # Central crop to half height and half width.
                layer0 = const[:, :, ch // 4 : 3 * ch // 4, cw // 4 : 3 * cw // 4]
            else:
                layer0 = const
            # NOTE(review): th.normal(0, <0-d tensor>) draws one scalar, so the same offset is
            # added to every element; per-element jitter would need size=layer0.shape -- confirm.
            self.input.input = th.nn.Parameter(layer0 + th.normal(0, const.std() / 2.0))
            _, _, height, width = self.input.input.shape

            # Replace the whole noise container, since the old buffers have the wrong shapes.
            # Calling a method on the attribute after `del self.noises` raises AttributeError.
            self.noises = nn.Module()
            for i in range(self.num_layers):
                # Same schedule as the default buffers above: layer 0 matches the input,
                # and each later pair of layers doubles it. 2 ** ((i + 5) // 2) == 4 * this.
                scale = 2 ** ((i + 1) // 2)
                self.noises.register_buffer(f"noise_{i}", th.randn(1, 1, height * scale, width * scale))
예제 #4
0
 def _define_model(self):
     """Attach ``self.model``: a container with three simpleCNNGenerator children w_1..w_3."""
     self.model = nn.Module()
     container = self.model
     for attr_name, width in (("w_1", 256), ("w_2", 512), ("w_3", 512)):
         # setattr on an nn.Module registers the child exactly like `container.w_k = ...`.
         setattr(container, attr_name, simpleCNNGenerator(width, width))
예제 #5
0
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        label_size=0,
    ):
        """Build a StyleGAN2-style generator, optionally conditioned on a label.

        Args:
            size: output resolution; ``log2(size)`` sets the number of layers.
            style_dim: width of the latent/style vectors.
            n_mlp: number of EqualLinear layers after PixelNorm in ``self.style``.
            channel_multiplier: scales channel counts for resolutions >= 64.
            blur_kernel: forwarded to every StyledConv.
            lr_mlp: learning-rate multiplier for the mapping network.
            label_size: when > 0, a LabelEmbed is created and the first mapping layer
                takes ``2 * style_dim`` inputs (presumably latent concatenated with the
                label embedding in forward -- confirm).
        """
        super().__init__()

        self.size = size
        self.style_dim = style_dim
        self.label_size = label_size

        layers = [PixelNorm()]

        if self.label_size > 0:
            print(
                "Detected conditional model, initializing LabelConcat layer..")
            self.label_concat = LabelEmbed(self.label_size, style_dim)

        for i in range(n_mlp):
            if self.label_size > 0 and i == 0:
                input_dim = 2 * style_dim
            else:
                input_dim = style_dim
            layers.append(
                EqualLinear(input_dim,
                            style_dim,
                            lr_mul=lr_mlp,
                            activation='fused_lrelu'))

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # Every StyledConv gets a sequential layerID: conv1 is 0, then 1, 2, 3, ...
        layer_index = 0
        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel,
                                layerID=layer_index)
        layer_index += 1
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Layer 0 noise is 4x4, then each pair of layers doubles the spatial size.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f'noise_{layer_idx}',
                                        torch.randn(*shape))

        # One (upsampling conv, conv, to_rgb) triple per resolution 8 .. size.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            # Use the running index here too; a hard-coded 0 tagged every
            # upsampling conv with the same ID as conv1 and skipped odd IDs.
            self.convs.append(
                StyledConv(in_channel,
                           out_channel,
                           3,
                           style_dim,
                           upsample=True,
                           blur_kernel=blur_kernel,
                           layerID=layer_index))
            layer_index += 1
            self.convs.append(
                StyledConv(out_channel,
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel,
                           layerID=layer_index))
            layer_index += 1

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2
예제 #6
0
def test_non_sequential():
    """GPipe refuses to wrap a module that is not an nn.Sequential."""
    pipe_kwargs = dict(balance=[1], devices=['cpu'])
    with pytest.raises(TypeError):
        GPipe(nn.Module(), **pipe_kwargs)
예제 #7
0
    def __init__(self,
                 size,
                 style_dim,
                 n_mlp,
                 code0_len=0,
                 code1_len=0,
                 stage0_depth=-1,
                 channel_multiplier=2,
                 blur_kernel=[1, 3, 3, 1],
                 lr_mlp=0.01):
        """Build a StyleGAN2-style generator widened by per-level "code" channels.

        Args:
            size: output resolution; ``log2(size)`` sets the number of layers.
            style_dim: width of the mapping network and the base style width.
            n_mlp: number of EqualLinear layers after PixelNorm in ``self.style``.
            code0_len: extra channels for levels with index <= ``stage0_depth``.
            code1_len: extra channels for levels with index > ``stage0_depth``.
            stage0_depth: last level index (0 = 4x4, 1 = 8x8, ..., 8 = 1024x1024) that
                uses ``code0_len``; the default -1 gives every level ``code1_len``.
            channel_multiplier: scales channel counts for resolutions >= 64.
            blur_kernel: forwarded to every StyledConv.
            lr_mlp: learning-rate multiplier for the mapping network.
        """
        super().__init__()

        self.size = size
        self.stage0_depth = stage0_depth
        self.style_dim = style_dim
        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(style_dim,
                            style_dim,
                            lr_mul=lr_mlp,
                            activation="fused_lrelu"))

        self.style = nn.Sequential(*layers)

        # code_channels[k] is the extra width for level k (resolution 2**(k + 2)).
        code_channels = []
        for i in range(0, 9):
            cc = code0_len
            if i > stage0_depth:
                cc = code1_len
            code_channels.append(cc)

        # Feature channels per resolution include that level's code channels.
        self.channels = {
            4: 512 + code_channels[0],
            8: 512 + code_channels[1],
            16: 512 + code_channels[2],
            32: 512 + code_channels[3],
            64: 256 * channel_multiplier + code_channels[4],
            128: 128 * channel_multiplier + code_channels[5],
            256: 64 * channel_multiplier + code_channels[6],
            512: 32 * channel_multiplier + code_channels[7],
            1024: 16 * channel_multiplier + code_channels[8],
        }

        # Style inputs at each level are style_dim + that level's code channels
        # (presumably the code is concatenated to the style vector in forward -- confirm).
        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim + code_channels[0],
                                blur_kernel=blur_kernel)
        self.to_rgb1 = ToRGB(self.channels[4],
                             style_dim + code_channels[0],
                             upsample=False)

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Layer 0 noise is 4x4, then each pair of layers doubles the spatial size.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f"noise_{layer_idx}",
                                        torch.randn(*shape))

        # Resolution 2**i is level i - 2, hence code_channels[i - 2].
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim + code_channels[i - 2],
                    upsample=True,
                    blur_kernel=blur_kernel,
                ))

            self.convs.append(
                StyledConv(out_channel,
                           out_channel,
                           3,
                           style_dim + code_channels[i - 2],
                           blur_kernel=blur_kernel))

            self.to_rgbs.append(
                ToRGB(out_channel, style_dim + code_channels[i - 2]))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 - 2
예제 #8
0
def _main():
    """Quick smoke-test script: log a VAETrainer wrapped around an empty nn.Module."""
    logging.basicConfig(level=logging.INFO)

    # Build the trainer and let its string form go to the log.
    trainer = VAETrainer(nn.Module())
    logger.info(trainer)
    def make_coefficient_params(self, lowres):
        """Build the sub-networks that predict the grid coefficients.

        Args:
            lowres: low-resolution input size; only ``min(lowres)`` is used (presumably
                a (height, width) pair -- confirm against callers).

        Returns:
            An ``nn.Module`` with children ``splat``, ``global_conv``, ``global_fc``,
            ``local`` and ``prediction``.
        """
        # splat params: stride-2 convs shrinking min(lowres) down to self.spatial_bin
        splat = []
        in_channels = self.n_in
        num_downsamples = int(np.log2(min(lowres) / self.spatial_bin))
        # One extra stride-1 conv per octave of spatial_bin above 16, placed at levels
        # spread evenly over the downsampling stages.
        n_extra_convs = max(0, int(np.log2(self.spatial_bin) - np.log2(16)))
        # `np.int` was removed in NumPy 1.24; the builtin `int` is the dtype it aliased.
        extra_conv_levels = np.linspace(0,
                                        num_downsamples - 1,
                                        n_extra_convs,
                                        dtype=int).tolist()
        for i in range(num_downsamples):
            out_channels = (2**i) * self.feature_multiplier
            splat.append(
                conv(in_channels,
                     out_channels,
                     3,
                     stride=2,
                     norm=False if i == 0 else self.norm))
            if i in extra_conv_levels:
                splat.append(
                    conv(out_channels, out_channels, 3, norm=self.norm))
            in_channels = out_channels
        splat = nn.Sequential(*splat)
        splat_channels = in_channels

        # global params: stride-2 convs down to spatial_bin / 2**k, then pool to 4x4
        global_conv = []
        in_channels = splat_channels
        for _ in range(int(np.log2(self.spatial_bin / 4))):
            global_conv.append(
                conv(in_channels,
                     8 * self.feature_multiplier,
                     3,
                     stride=2,
                     norm=self.norm))
            in_channels = 8 * self.feature_multiplier
        global_conv.append(nn.AdaptiveAvgPool2d(4))
        global_conv = nn.Sequential(*global_conv)
        # 128 * fm == (8 * fm channels) * (4 * 4 pooled cells); presumably forward
        # flattens global_conv's output before this -- confirm.
        global_fc = nn.Sequential(
            fc(128 * self.feature_multiplier,
               32 * self.feature_multiplier,
               norm=self.norm),
            fc(32 * self.feature_multiplier,
               16 * self.feature_multiplier,
               norm=self.norm),
            fc(16 * self.feature_multiplier,
               8 * self.feature_multiplier,
               norm=False,
               relu=False))

        # local params
        local = nn.Sequential(
            conv(splat_channels, 8 * self.feature_multiplier, 3),
            conv(8 * self.feature_multiplier,
                 8 * self.feature_multiplier,
                 3,
                 bias=False,
                 norm=False,
                 relu=False))

        # prediction params: luma_bins * (n_in + 1) * n_out output channels
        prediction = conv(8 * self.feature_multiplier,
                          self.luma_bins * (self.n_in + 1) * self.n_out,
                          1,
                          norm=False,
                          relu=False)

        coefficient_params = nn.Module()
        coefficient_params.splat = splat
        coefficient_params.global_conv = global_conv
        coefficient_params.global_fc = global_fc
        coefficient_params.local = local
        coefficient_params.prediction = prediction
        return coefficient_params
예제 #10
0
    # Evaluate `model` on (X, y) without gradient tracking.
    # Returns (loss, evaluations):
    #   loss is None when `criterion` is falsy; a float for a single loss; or
    #   (total_float, [per_task_float, ...]) when criterion returns (total, losses).
    #   evaluations holds eval_func(y_numpy, output_numpy) for each eval_func.
    # NOTE(review): the enclosing `def` line is missing from this page; model, X, y,
    # criterion and eval_funcs are presumably its parameters -- confirm.
    # The model is left in eval mode (no model.train() afterwards).
    loss = None
    evaluations = []

    model.eval()
    with torch.no_grad():
        output = model(X)

        if criterion:
            loss = criterion(output, y)

            if isinstance(loss, tuple):
                # for multiple tasks
                total_loss, losses = loss
                # The comprehension's `loss` is local to it (Python 3), so only this
                # assignment rebinds the outer `loss`.
                loss = total_loss.item(), [loss.item() for loss in losses]
            else:
                loss = loss.item()

    # NOTE(review): `.numpy()` raises for CUDA tensors; if the model can run on a GPU,
    # `.cpu()` is needed first -- confirm.
    evaluations = [
        eval_func(y.detach().numpy(),
                  output.detach().numpy()) for eval_func in eval_funcs
    ]

    return loss, evaluations


if __name__ == "__main__":
    # Print the repr of an empty module (the unused `import inspect` was dropped).
    test = nn.Module()
    print(str(test))
예제 #11
0
    def __init__(self, path: str = None, features: int = 256):
        """Init: ResNet-50 encoder plus refinement/fusion decoder.

        Arguments:
            path (str, optional): Saved-model path, passed to ``self.load`` when truthy.
                Defaults to None.
            features (int, optional): Channel count every encoder stage is projected
                to. Defaults to 256.
        """
        super().__init__()

        backbone = models.resnet50(pretrained=False)

        self.pretrained = nn.Module()
        self.scratch = nn.Module()

        # Encoder: the ResNet stem (conv/bn/relu/maxpool) is folded into stage 1.
        self.pretrained.layer1 = nn.Sequential(
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
        )
        self.pretrained.layer2 = backbone.layer2
        self.pretrained.layer3 = backbone.layer3
        self.pretrained.layer4 = backbone.layer4

        # Project the four encoder stages to `features` channels with bias-free 3x3
        # convs; built and registered in stage order layer1_rn .. layer4_rn.
        stage_widths = (256, 512, 1024, 2048)
        for stage, width in enumerate(stage_widths, start=1):
            projection = nn.Conv2d(width,
                                   features,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1,
                                   bias=False)
            setattr(self.scratch, f"layer{stage}_rn", projection)

        # Fusion blocks, registered deepest first: refinenet4, 3, 2, 1.
        for stage in (4, 3, 2, 1):
            setattr(self.scratch, f"refinenet{stage}", FeatureFusionBlock(features))

        # Output head: two 3x3 convs down to one channel, then 2x bilinear upsampling.
        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
        )

        if path:
            self.load(path)
예제 #12
0
    def __init__(
        self,
        size_h,
        size_w,
        log_size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        """Build a StyleGAN2-style generator with a rectangular (size_h x size_w) output.

        The constant input is ``(size_h, size_w) // 2**log_size``, and ``log_size``
        upsampling stages restore the full size.

        Args:
            size_h: output height; must be divisible by ``2**log_size``.
            size_w: output width; must be divisible by ``2**log_size``.
            log_size: number of 2x upsampling stages. Channel keys up to
                ``2**(log_size + 2)`` are looked up in ``self.channels``, so values
                above 8 raise KeyError.
            style_dim: width of the latent/style vectors.
            n_mlp: number of EqualLinear layers after PixelNorm in ``self.style``.
            channel_multiplier: scales channel counts for keys >= 64.
            blur_kernel: forwarded to every StyledConv.
            lr_mlp: learning-rate multiplier for the mapping network.
        """
        super().__init__()

        # NOTE: assert-based validation is skipped when Python runs with -O.
        assert size_w % 2**log_size == 0, f'Width {size_w} is not divisible by {2**log_size}'
        assert size_h % 2**log_size == 0, f'Height {size_h} is not divisible by {2**log_size}'
        self.size_h = size_h
        self.size_w = size_w
        self.log_size = log_size
        
        # Spatial size of the constant input, before any upsampling.
        self.init_h = self.size_h // 2**self.log_size
        self.init_w = self.size_w // 2**self.log_size

        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
                )
            )

        self.style = nn.Sequential(*layers)

        # Keys index channel width by "level"; here key 4 is the init_h x init_w input,
        # not necessarily a 4x4 map.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4], size_h=self.init_h, size_w=self.init_w)
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        # conv1 plus two StyledConvs per upsampling stage.
        self.num_layers = self.log_size * 2 + 1

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Layer 0 noise matches the input; layers 2k-1 and 2k are 2**k times larger.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 1) // 2
            shape = [1, 1, self.init_h * 2 ** res, self.init_w * 2 ** res]
            self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

        # log_size stages, each an (upsampling conv, conv, to_rgb) triple.
        for i in range(3, self.log_size + 3):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.n_latent = self.log_size * 2 + 2
예제 #13
0
파일: model.py 프로젝트: ysngshn/lgm
 def build_layers(self) -> nn.Module:
     """Wrap each entry of ``self.proto.layers`` in a VariableLayer, keyed by its name."""
     container = nn.Module()
     for wrapped in map(VariableLayer, self.proto.layers):
         container.add_module(wrapped.name, wrapped)
     return container
예제 #14
0
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
        constant_input=False,
        checkpoint=None,
        output_size=None,
        min_rgb_size=4,
        base_res_factor=1,
    ):
        """Build a StyleGAN2-style generator whose noise can be resized for other outputs.

        Args:
            size: nominal square resolution; ``log2(size)`` sets the layer count.
            style_dim: width of the latent/style vectors.
            n_mlp: number of EqualLinear layers after PixelNorm in ``self.style``.
            channel_multiplier: scales channel counts for resolutions >= 64.
            blur_kernel: forwarded to every StyledConv.
            lr_mlp: learning-rate multiplier for the mapping network.
            constant_input: use ConstantInput instead of LatentInput as the 4x4 input.
            checkpoint: optional path; its ``"g_ema"`` entry is loaded into this module.
            output_size: target output size. When given and different from ``size``, the
                noise buffers are re-created: 1080 doubles their height and 1920 doubles
                their width. None means "same as size".
            min_rgb_size: stored on ``self.min_rgb_size`` (presumably read by forward --
                confirm).
            base_res_factor: multiplier on every noise buffer's height and width; any
                value other than 1 also re-creates the noise buffers.
        """
        super().__init__()

        self.size = size
        self.style_dim = style_dim

        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(style_dim,
                            style_dim,
                            lr_mul=lr_mlp,
                            activation="fused_lrelu"))

        self.style = nn.Sequential(*layers)

        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.log_size = int(math.log(size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.n_latent = self.log_size * 2 - 2
        self.min_rgb_size = min_rgb_size

        if constant_input:
            self.input = ConstantInput(self.channels[4])
        else:
            self.input = LatentInput(style_dim, self.channels[4])

        self.const_manipulation = ManipulationLayer(0)

        layerID = 1
        self.conv1 = StyledConv(self.channels[4],
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel,
                                layerID=layerID)
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.convs = nn.ModuleList()
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # Default noise: layer 0 is 4x4, then each pair of layers doubles the size.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f"noise_{layer_idx}", th.randn(*shape))

        # One (upsampling conv, conv, to_rgb) triple per resolution 8 .. size.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            layerID += 1
            self.convs.append(
                StyledConv(in_channel,
                           out_channel,
                           3,
                           style_dim,
                           upsample=True,
                           blur_kernel=blur_kernel,
                           layerID=layerID))

            layerID += 1
            self.convs.append(
                StyledConv(out_channel,
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel,
                           layerID=layerID))

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        self.truncation_latent = None

        if checkpoint is not None:
            self.load_state_dict(th.load(checkpoint)["g_ema"])

        # Re-create noise only when the output geometry actually changes. With the
        # default output_size=None, `size != output_size` alone is always true, which
        # replaced noise buffers just restored from `checkpoint` with fresh random ones.
        if (output_size is not None and size != output_size) or base_res_factor != 1:
            for layer_idx in range(self.num_layers):
                res = (layer_idx + 5) // 2
                shape = [
                    1,
                    1,
                    int(base_res_factor * 2**res *
                        (2 if output_size == 1080 else 1)),
                    int(base_res_factor * 2**res *
                        (2 if output_size == 1920 else 1)),
                ]
                # Assigning a tensor to an existing buffer name replaces that buffer.
                setattr(self.noises, f"noise_{layer_idx}", th.randn(*shape))
예제 #15
0
def main(args):
    """Evaluate saved Classifier checkpoints on few-shot episodes.

    Reads dataset and normalisation settings from the JSON file ``args.cfg`` and rebuilds
    the checkpoint filename prefix from the hyper-parameters in ``args``. For every
    matching ``*pth`` file it loads the weights, embeds each episode's graphs, classifies
    the query part by pairwise distance to the first ``n_shot * k_way`` (support)
    samples, and prints the mean accuracy over ``args.episodes`` episodes.

    Args:
        args: parsed command-line namespace. Fields read here: gpu, cfg, batch_size,
            epochs, n_hidden, n_layers, hops, dropout, readout, n_grid, n_shot, k_way,
            episodes, samples, ctype, K, nbnn, dist.
    """
    if args.gpu < 0:
        device = torch.device("cpu")
    else:
        device = torch.device("cuda:" + str(args.gpu))

    # Read the config with a context manager so the file handle is closed promptly.
    with open(args.cfg) as cfg_file:
        config_file = json.load(cfg_file)
    # Several of these keys are not used below, but reading them still requires them
    # to be present in the config.
    test_dir = config_file['test_dir']
    train_dir = config_file['train_dir']
    dataset = config_file['dataset']
    cache_io = config_file['cache_io']
    napp = config_file['node_app']
    eapp = config_file['edge_app']
    symm_io = config_file['symm_io']
    shuffle_io = config_file['shuffle_io']
    n_classes = config_file['num_classes']
    apply_da = config_file['data_augment']
    rad_scale = config_file['rad_scale']
    angle_scale = config_file['angle_scale']
    length_scale = config_file['length_scale']
    curve_scale = config_file['curve_scale']
    poly_scale = config_file['poly_scale']
    batch_io = args.batch_size
    epochs = args.epochs
    bdir = os.path.basename(train_dir)

    # Node feature width: 58 base features, +21 with node appearance, +9 with edge appearance.
    input_dim = 58
    if napp:
        input_dim = input_dim + 21

    if eapp:
        input_dim = input_dim + 9

    norm_factors = {
        'rad_scale': rad_scale,
        'angle_scale': angle_scale,
        'length_scale': length_scale,
        'curve_scale': curve_scale,
        'poly_scale': poly_scale
    }

    # Checkpoint filename prefix, which must match the one used when training.
    prefix = 'data-' + str(bdir) + ':' + 'mign' + '_m-tag_ni-' + str(
        input_dim) + '_nh-' + str(args.n_hidden) + '_lay-' + str(
            args.n_layers) + '_hops-' + str(args.hops) + '_napp-' + str(
                napp) + '_eapp-' + str(eapp) + '_do-' + str(
                    args.dropout) + '_ro-' + str(args.readout)

    if args.readout == 'spp':
        extra = '_ng-' + str(args.n_grid)
        prefix += extra

    extra = '_b-' + str(batch_io)
    prefix += extra

    print('saving to prefix: ', prefix)

    # Episodic few-shot dataset over test_dir: each item is one episode holding
    # n_shot * k_way support samples followed by the query samples.
    fsl_dataset = ShockGraphDataset(test_dir,
                                    dataset,
                                    norm_factors,
                                    n_shot=args.n_shot,
                                    k_way=args.k_way,
                                    episodes=args.episodes,
                                    test_samples=args.samples,
                                    node_app=napp,
                                    edge_app=eapp,
                                    cache=True,
                                    symmetric=symm_io,
                                    data_augment=False,
                                    grid=args.n_grid)

    model_files = glob.glob(prefix + '*pth')
    model_files.sort()

    numb_train = args.n_shot * args.k_way

    for state_path in model_files:
        print('Using weights: ', state_path)

        model = Classifier(input_dim, args.n_hidden, n_classes, args.n_layers,
                           args.ctype, args.hops, args.readout, F.relu,
                           args.dropout, args.n_grid, args.K, device)

        # Embedding layer: the last graph layer for NBNN, otherwise the readout head.
        layer = nn.Module()
        for name, module in model.named_children():
            if args.nbnn:
                if name == 'layers':
                    layer = module[-1]
            else:
                if name == 'readout_fcn':
                    layer = module

        # map_location lets GPU-saved checkpoints load when running on CPU.
        model.load_state_dict(
            torch.load(state_path, map_location=device)['model_state_dict'])
        model.to(device)
        model.eval()

        class_accuracy = np.zeros(args.episodes)
        for idx in tqdm(range(args.episodes)):

            bg, label = fsl_dataset[idx]

            if args.nbnn:
                embeddings = im2set(bg, layer, model, args.n_hidden)
            else:
                embeddings = im2vec(bg, layer, model, args.n_hidden)

            # NBNN embeds every node, so the support part spans all nodes of the
            # first numb_train graphs; otherwise there is one row per graph.
            if args.nbnn:
                support_exemplars = np.sum(bg.batch_num_nodes[:numb_train])
            else:
                support_exemplars = numb_train

            train_embeddings = embeddings[:support_exemplars, :]
            if args.nbnn:
                train_labels = np.repeat(label[:numb_train],
                                         bg.batch_num_nodes[:numb_train])
            else:
                train_labels = label[:numb_train]

            test_embeddings = embeddings[support_exemplars:, :]
            test_labels = label[numb_train:]

            D = all_pairwise_distance(test_embeddings, train_embeddings,
                                      args.dist)

            if args.nbnn:
                predicted = predict_nbnn(D, train_labels,
                                         bg.batch_num_nodes[numb_train:],
                                         test_labels.shape[0])
            else:
                predicted = predict(D, train_labels)

            groundtruth = test_labels

            gg = torch.sum(predicted == groundtruth) / float(len(groundtruth))
            class_accuracy[idx] = gg

        print('Class Accuracy:{:4f}%'.format(np.mean(class_accuracy) * 100.0))
        del model
    def __init__(self,
                 out_size,
                 num_style_feat=512,
                 num_mlp=8,
                 channel_multiplier=2,
                 resample_kernel=(1, 3, 3, 1),
                 lr_mlp=0.01,
                 narrow=1):
        """Build a StyleGAN2 generator.

        Args:
            out_size: output resolution; ``log2(out_size)`` sets the number of layers.
            num_style_feat: width of the style vectors and of the style MLP.
            num_mlp: number of EqualLinear layers after NormStyleCode in ``style_mlp``.
            channel_multiplier: scales channel counts for resolutions >= 64.
            resample_kernel: forwarded to every StyleConv and ToRGB.
            lr_mlp: learning-rate multiplier for the style MLP.
            narrow: global width factor applied to every channel count (``int``-truncated).
        """
        super(StyleGAN2Generator, self).__init__()
        # Style MLP layers
        self.num_style_feat = num_style_feat
        style_mlp_layers = [NormStyleCode()]
        for i in range(num_mlp):
            style_mlp_layers.append(
                EqualLinear(
                    num_style_feat,
                    num_style_feat,
                    bias=True,
                    bias_init_val=0,
                    lr_mul=lr_mlp,
                    activation='fused_lrelu'))
        self.style_mlp = nn.Sequential(*style_mlp_layers)

        # Channel widths keyed by resolution as a *string* ('4', '8', ...).
        channels = {
            '4': int(512 * narrow),
            '8': int(512 * narrow),
            '16': int(512 * narrow),
            '32': int(512 * narrow),
            '64': int(256 * channel_multiplier * narrow),
            '128': int(128 * channel_multiplier * narrow),
            '256': int(64 * channel_multiplier * narrow),
            '512': int(32 * channel_multiplier * narrow),
            '1024': int(16 * channel_multiplier * narrow)
        }
        self.channels = channels

        # 4x4 stage: learned constant -> StyleConv (no resampling) -> ToRGB.
        self.constant_input = ConstantInput(channels['4'], size=4)
        self.style_conv1 = StyleConv(
            channels['4'],
            channels['4'],
            kernel_size=3,
            num_style_feat=num_style_feat,
            demodulate=True,
            sample_mode=None,
            resample_kernel=resample_kernel)
        self.to_rgb1 = ToRGB(
            channels['4'],
            num_style_feat,
            upsample=False,
            resample_kernel=resample_kernel)

        self.log_size = int(math.log(out_size, 2))
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.num_latent = self.log_size * 2 - 2

        self.style_convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        self.noises = nn.Module()

        in_channels = channels['4']
        # noise: layer 0 is 4x4, then each pair of layers doubles the size.
        # Buffer names here are 'noise0', 'noise1', ... (no underscore).
        for layer_idx in range(self.num_layers):
            resolution = 2**((layer_idx + 5) // 2)
            shape = [1, 1, resolution, resolution]
            self.noises.register_buffer(f'noise{layer_idx}',
                                        torch.randn(*shape))
        # style convs and to_rgbs: one (upsample conv, conv, upsampling ToRGB) triple
        # per resolution 8 .. out_size
        for i in range(3, self.log_size + 1):
            out_channels = channels[f'{2**i}']
            self.style_convs.append(
                StyleConv(
                    in_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode='upsample',
                    resample_kernel=resample_kernel,
                ))
            self.style_convs.append(
                StyleConv(
                    out_channels,
                    out_channels,
                    kernel_size=3,
                    num_style_feat=num_style_feat,
                    demodulate=True,
                    sample_mode=None,
                    resample_kernel=resample_kernel))
            self.to_rgbs.append(
                ToRGB(
                    out_channels,
                    num_style_feat,
                    upsample=True,
                    resample_kernel=resample_kernel))
            in_channels = out_channels
예제 #17
0
def test_verify_module_non_sequential(setup_rpc):
    """Pipe must reject a module that is not an ``nn.Sequential``.

    ``setup_rpc`` is a fixture required for Pipe; its value is not used here.
    """
    expected_message = "module must be nn.Sequential to be partitioned"
    not_sequential = nn.Module()
    with pytest.raises(TypeError, match=expected_message):
        Pipe(not_sequential)
예제 #18
0
    def __init__(
        self,
        size,
        style_dim,
        args,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        """Build a StyleGAN2-like synthesis network plus a stack of ``ConvLayer``s.

        The layout (StyledConv / ToRGB / per-layer noise buffers) follows the
        StyleGAN2 generator pattern.

        Args:
            size: output resolution. ``int(math.log(size, 2))`` sets how many
                stages are built, so a non-power-of-two is truncated down.
                Sizes above 1024 make ``self.channels[2**i]`` raise ``KeyError``.
            style_dim: width of the style vector. It is also used here as the
                channel count of the constant input and the input of ``conv1``.
            args: config object. This method reads ``args.translayer`` (how many
                ``ConvLayer`` blocks go into ``conv_Trans``) and ``args.latent``
                (their in/out channels).
            channel_multiplier: scales the channel counts of the 64..1024 stages.
            blur_kernel: forwarded to every ``StyledConv``. NOTE(review): this is
                a mutable list default shared across calls. It is safe only as
                long as no callee mutates it.
            lr_mlp: accepted but not used anywhere in this method.
        """
        super().__init__()

        self.size = size
        self.style_dim = style_dim

        # Channel count per feature-map resolution; indexed below with
        # self.channels[4] and self.channels[2**i].
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        # Learned constant starting tensor with style_dim channels.
        # NOTE(review): size=4 is presumably its spatial size, not the output
        # `size` argument; ConstantInput's signature is not visible, so confirm.
        self.input = ConstantInput(style_dim, size=4)

        # args.translayer ConvLayers mapping args.latent -> args.latent channels
        # (the third argument is presumably a 1x1 kernel size; confirm). The
        # loop index itself is unused.
        conv_Trans = []
        for medium_layer_idx in range(args.translayer):
            conv_Trans.append(ConvLayer(args.latent, args.latent, 1))
        self.conv_Trans = nn.Sequential(*conv_Trans)

        # First styled conv at the 4x4 stage (no upsample flag passed) and its
        # RGB projection.
        self.conv1 = StyledConv(style_dim,
                                self.channels[4],
                                3,
                                style_dim,
                                blur_kernel=blur_kernel)
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        # conv1 plus two StyledConvs per stage for stages 8 .. 2**log_size.
        self.num_layers = (self.log_size - 2) * 2 + 1
        self.convs = nn.ModuleList()
        # NOTE(review): never populated in this method.
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        # Plain container that only holds the noise buffers registered below.
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # One fixed random noise buffer per styled conv. Layer 0 is 4x4, then
        # two layers each at 8, 16, ..., 2**log_size, since res = (idx + 5) // 2.
        # Buffers are not parameters, so they are not trained.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2**res, 2**res]
            self.noises.register_buffer(f'noise_{layer_idx}',
                                        torch.randn(*shape))

        # Each stage: an upsampling StyledConv, a same-resolution StyledConv,
        # and a ToRGB for that stage's output channels.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2**i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                ))

            self.convs.append(
                StyledConv(out_channel,
                           out_channel,
                           3,
                           style_dim,
                           blur_kernel=blur_kernel))

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        # Presumably the number of style vectors the forward pass consumes
        # (confirm in forward).
        self.n_latent = self.log_size * 2 - 2
예제 #19
0
 def get_features(self, original_model):
     """Return a plain ``nn.Module`` holding all but the last three children.

     The kept children are shared with ``original_model``, not copied, and keep
     their original names and order. The result is a bare ``nn.Module``, not an
     ``nn.Sequential``, so it has no ``forward`` of its own.

     Args:
         original_model: module whose ``named_children()`` are truncated.

     Returns:
         A new ``nn.Module`` with every child except the final three.
     """
     kept_children = list(original_model.named_children())[:-3]
     container = nn.Module()
     for child_name, child in kept_children:
         container.add_module(child_name, child)
     return container
예제 #20
0
 def _reconstruct_densenet(self, basemodel):
     model = nn.Module()
     return model
예제 #21
0
        encoder_dim = 512
        encoder = models.vgg16(pretrained=pretrained)
        # capture only feature part and remove last relu and maxpool
        layers = list(encoder.features.children())[:-2]

        if pretrained:
            # if using pretrained then only train conv5_1, conv5_2, and conv5_3
            for l in layers[:-5]: 
                for p in l.parameters():
                    p.requires_grad = False

    if opt.mode.lower() == 'cluster' and not opt.vladv2:
        layers.append(L2Norm())

    encoder = nn.Sequential(*layers)
    model = nn.Module() 
    model.add_module('encoder', encoder)

    if opt.mode.lower() != 'cluster':
        if opt.pooling.lower() == 'netvlad':
            net_vlad = netvlad.NetVLAD(num_clusters=opt.num_clusters, dim=encoder_dim, vladv2=opt.vladv2)
            if not opt.resume: 
                if opt.mode.lower() == 'train':
                    initcache = join(opt.dataPath, 'centroids', opt.arch + '_' + train_set.dataset + '_' + str(opt.num_clusters) +'_desc_cen.hdf5')
                else:
                    initcache = join(opt.dataPath, 'centroids', opt.arch + '_' + whole_test_set.dataset + '_' + str(opt.num_clusters) +'_desc_cen.hdf5')

                if not exists(initcache):
                    raise FileNotFoundError('Could not find clusters, please run with --mode=cluster before proceeding')

                with h5py.File(initcache, mode='r') as h5: 
예제 #22
0
    def __init__(
        self,
        resolution=1024,
        w_space_dim=512,
        n_mlp=8,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        """Build a StyleGAN2-like generator with a mapping network.

        Args:
            resolution: output resolution. ``int(math.log(resolution, 2))`` sets
                the number of stages, so a non-power-of-two is truncated down.
                Values above 1024 make ``self.channels[2 ** i]`` raise ``KeyError``.
            w_space_dim: style-vector width passed to every StyledConv / ToRGB.
            n_mlp: forwarded to ``StyleGAN2Transformer``.
            channel_multiplier: scales the channel counts of the 64..1024 stages.
            blur_kernel: forwarded to every styled conv. NOTE(review): this is a
                mutable list default shared across calls.
            lr_mlp: forwarded to ``StyleGAN2Transformer``.
        """
        super().__init__()

        self.resolution = resolution

        self.w_space_dim = w_space_dim
        self.n_mlp = n_mlp
        self.channel_multiplier = channel_multiplier

        self.log_size = int(math.log(resolution, 2))
        # conv1 plus two styled convs per stage for stages 8 .. 2**log_size.
        self.num_layers = (self.log_size - 2) * 2 + 1
        # Presumably the number of style vectors consumed in forward (confirm).
        self.num_latents = self.log_size * 2 - 2

        # Only the `.style` submodule of a freshly built StyleGAN2Transformer is
        # kept. NOTE(review): the whole transformer is constructed (and its
        # initialisation consumes RNG) just to take this attribute.
        self.style = StyleGAN2Transformer(w_space_dim, n_mlp, lr_mlp).style

        # Channel count per feature-map resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, w_space_dim, blur_kernel=blur_kernel)
        # NOTE(review): presumably non-upsampling, since the per-stage layers use
        # the separate ToRGBWithUpsample class; confirm.
        self.to_rgb1 = ToRGB(self.channels[4], w_space_dim)

        self.convs = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        # Plain container that only holds the noise buffers registered below.
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # One fixed random noise buffer per styled conv. Layer 0 is 4x4, then
        # two layers each at 8, 16, ..., 2**log_size, since res = (idx + 5) // 2.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(
                f"noise_{layer_idx}", torch.randn(*shape))

        # Each stage: an upsampling styled conv, a same-channel styled conv, and
        # an upsampling ToRGB.
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConvWithUpsample(in_channel, out_channel, 3, w_space_dim, blur_kernel=blur_kernel)
            )

            self.convs.append(
                StyledConv(out_channel, out_channel, 3, w_space_dim, blur_kernel=blur_kernel)
            )
            self.to_rgbs.append(ToRGBWithUpsample(out_channel, w_space_dim))

            in_channel = out_channel
예제 #23
0
파일: main.py 프로젝트: MalteEbner/NNCLR
 def __init__(self, dataloader_kNN):
     """Initialise with a placeholder backbone and kNN-evaluation state.

     Args:
         dataloader_kNN: stored as-is on ``self.dataloader_kNN``. It is
             presumably the loader used for kNN accuracy monitoring; confirm
             with the caller.
     """
     super().__init__()
     # Empty placeholder module. NOTE(review): nothing in this method replaces
     # it with a real backbone; presumably that happens elsewhere.
     self.backbone = nn.Module()
     # Starts at 0.0. Presumably tracks the best accuracy seen so far (confirm).
     self.max_accuracy = 0.0
     self.dataloader_kNN = dataloader_kNN
예제 #24
0
 def test_repr_smoke(self):
     """repr() of a LocalImage must not raise and must yield a str."""
     dummy_transform = nn.Module()
     local_image = data.LocalImage("image", transform=dummy_transform, note="note")
     rendered = repr(local_image)
     assert isinstance(rendered, str)
예제 #25
0
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                 resolution, z_channels, double_z=True, **ignore_kwargs):
        """Build a convolutional encoder in four parts.

        The parts are a stem conv, a per-level ResNet/attention pyramid with
        downsampling between levels, a middle ResNet-attention-ResNet block,
        and a normalised output conv. All arguments are keyword-only.

        Args:
            ch: base channel width. Level ``i`` outputs ``ch * ch_mult[i]`` channels.
            out_ch: accepted but not used in this method.
            ch_mult: per-level channel multipliers. ``len(ch_mult)`` is the
                number of levels.
            num_res_blocks: number of ResnetBlocks per level.
            attn_resolutions: an AttnBlock follows each ResnetBlock on levels
                whose tracked resolution (``curr_res``) is in this collection.
            dropout: forwarded to every ResnetBlock.
            resamp_with_conv: forwarded to every Downsample.
            in_channels: input channels of ``conv_in``.
            resolution: starting value of ``curr_res``. It is halved after every
                level except the last.
            z_channels: ``conv_out`` emits ``2 * z_channels`` channels when
                ``double_z`` is true, otherwise ``z_channels``. The doubled form
                is presumably mean and log-variance; confirm in the consumer.
            double_z: see ``z_channels``.
            **ignore_kwargs: accepted and not used.
        """
        super().__init__()
        self.ch = ch
        # No timestep embedding: 0 is passed as temb_channels to every ResnetBlock.
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # Prepending 1 makes in_ch_mult[i] the previous level's multiplier, so
        # level i's first block takes the previous level's output width.
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            # Plain nn.Module used as a namespace so the parameters register
            # under down.<i>.block / down.<i>.attn / down.<i>.downsample.
            down = nn.Module()
            down.block = block
            down.attn = attn
            # The last level has no Downsample.
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        2*z_channels if double_z else z_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
예제 #26
0
    def __init__(self,
                 classes,
                 rel_classes,
                 mode='sgdet',
                 num_gpus=1,
                 require_overlap_det=True,
                 embed_dim=200,
                 hidden_dim=4096,
                 use_resnet=False,
                 thresh=0.01,
                 use_proposals=False,
                 depth_model=None,
                 pretrained_depth=False,
                 active_features=None,
                 frozen_features=None,
                 use_embed=False,
                 **kwargs):
        """
        Build the relation-detection model.

        The model has per-feature branches (visual, location, class, depth),
        each enabled by the base class's has_* flags. Their outputs feed a
        fusion hidden layer and a final relation classifier.

        :param classes: object classes
        :param rel_classes: relationship classes. None if we're not using rel mode
        :param mode: (sgcls, predcls, or sgdet)
        :param num_gpus: how many GPUs to use
        :param require_overlap_det: whether two objects must intersect
        :param embed_dim: word2vec embeddings dimension
        :param hidden_dim: dimension of the fusion hidden layer
        :param use_resnet: use resnet as faster-rcnn's backbone (not supported:
            passing True raises ValueError)
        :param thresh: faster-rcnn related threshold (Threshold for calling it a good box)
        :param use_proposals: whether to use region proposal candidates
        :param depth_model: provided architecture for depth feature extraction;
            must be a member of DEPTH_MODELS (checked with assert).
            NOTE(review): the default None only works if None is in DEPTH_MODELS.
        :param pretrained_depth: whether the depth feature extractor should be initialized with ImageNet weights
        :param active_features: what set of features should be enabled (e.g. 'vdl' : visual, depth, and location features)
        :param frozen_features: what set of features should be frozen (e.g. 'd' : depth)
        :param use_embed: use word2vec embeddings
        :param kwargs: accepted and not used
        :raises AssertionError: if depth_model is not in DEPTH_MODELS
        :raises ValueError: if use_resnet is True
        """
        RelModelBase.__init__(self, classes, rel_classes, mode, num_gpus,
                              require_overlap_det, active_features,
                              frozen_features)
        self.pooling_size = 7
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        # Effectively always 4096: use_resnet=True raises ValueError below.
        self.obj_dim = 2048 if use_resnet else 4096

        # -- Store depth related parameters
        assert depth_model in DEPTH_MODELS
        self.depth_model = depth_model
        self.pretrained_depth = pretrained_depth
        self.depth_pooling_dim = DEPTH_DIMS[self.depth_model]
        self.use_embed = use_embed
        # Placeholder so self.detector exists even when the visual branch is
        # disabled; replaced by ObjectDetector below when has_visual is set.
        self.detector = nn.Module()
        # Running total of the enabled branches' hidden sizes; it becomes the
        # input width of the fusion layer.
        features_size = 0

        # -- Check whether ResNet is selected as faster-rcnn's backbone
        if use_resnet:
            raise ValueError(
                "The current model does not support ResNet as the Faster-RCNN's backbone."
            )
        """ *** DIFFERENT COMPONENTS OF THE PROPOSED ARCHITECTURE *** 
        This is the part where the different components of the proposed relation detection 
        architecture are defined. In the case of RGB images, we have class probability distribution
        features, visual features, and the location ones. If we are considering depth images as well,
        we augment depth features too. """

        # -- Visual features
        if self.has_visual:
            # -- Define faster R-CNN network and its related feature extractors
            self.detector = ObjectDetector(
                classes=classes,
                mode=('proposals' if use_proposals else 'refinerels')
                if mode == 'sgdet' else 'gtbox',
                use_resnet=use_resnet,
                thresh=thresh,
                max_per_img=64,
            )
            self.roi_fmap_obj = load_vgg(pretrained=False).classifier

            # -- Define visual features hidden layer
            # Input is obj_dim * 2, presumably a subject/object pair of RoI
            # features concatenated (confirm in forward).
            self.visual_hlayer = nn.Sequential(*[
                xavier_init(nn.Linear(self.obj_dim * 2, self.FC_SIZE_VISUAL)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.8)
            ])
            self.visual_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_VISUAL

        # -- Location features
        if self.has_loc:
            # -- Define location features hidden layer
            self.location_hlayer = nn.Sequential(*[
                xavier_init(nn.Linear(self.LOC_INPUT_SIZE, self.FC_SIZE_LOC)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1)
            ])
            self.location_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_LOC

        # -- Class features
        if self.has_class:
            if self.use_embed:
                # -- Define class embeddings
                # Initialised from obj_edge_vectors (presumably pretrained
                # word vectors for self.classes; confirm).
                embed_vecs = obj_edge_vectors(self.classes,
                                              wv_dim=self.embed_dim)
                self.obj_embed = nn.Embedding(self.num_classes, self.embed_dim)
                self.obj_embed.weight.data = embed_vecs.clone()

            classme_input_dim = self.embed_dim if self.use_embed else self.num_classes
            # -- Define Class features hidden layer
            self.classme_hlayer = nn.Sequential(*[
                xavier_init(
                    nn.Linear(classme_input_dim * 2, self.FC_SIZE_CLASS)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1)
            ])
            self.classme_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_CLASS

        # -- Depth features
        if self.has_depth:
            # -- Initialize depth backbone
            self.depth_backbone = DepthCNN(depth_model=self.depth_model,
                                           pretrained=self.pretrained_depth)

            # -- Create a relation head which is used to carry on the feature extraction
            # from RoIs of depth features
            self.depth_rel_head = self.depth_backbone.get_classifier()

            # -- Define depth features hidden layer
            self.depth_rel_hlayer = nn.Sequential(*[
                xavier_init(
                    nn.Linear(self.depth_pooling_dim * 2, self.FC_SIZE_DEPTH)),
                nn.ReLU(inplace=True),
                nn.Dropout(0.6),
            ])
            self.depth_scale = ScaleLayer(1.0)
            features_size += self.FC_SIZE_DEPTH

        # -- *** Fusion layer *** --
        # -- A hidden layer for concatenated features (fusion features)
        self.fusion_hlayer = nn.Sequential(*[
            xavier_init(nn.Linear(features_size, self.hidden_dim)),
            nn.ReLU(inplace=True),
            nn.Dropout(0.1)
        ])

        # -- Final FC layer which predicts the relations
        self.rel_out = xavier_init(
            nn.Linear(self.hidden_dim, self.num_rels, bias=True))

        # -- Freeze the user specified features
        # NOTE(review): roi_fmap_obj, visual_hlayer, classme_hlayer,
        # location_hlayer and the depth_* modules are created in this method
        # only under the matching has_* flag. If a frz_* flag can be set
        # without its has_* flag, these lines raise AttributeError unless
        # RelModelBase defines the attributes. Confirm.
        if self.frz_visual:
            self.freeze_module(self.detector)
            self.freeze_module(self.roi_fmap_obj)
            self.freeze_module(self.visual_hlayer)

        if self.frz_class:
            self.freeze_module(self.classme_hlayer)

        if self.frz_loc:
            self.freeze_module(self.location_hlayer)

        if self.frz_depth:
            self.freeze_module(self.depth_backbone)
            self.freeze_module(self.depth_rel_head)
            self.freeze_module(self.depth_rel_hlayer)
예제 #27
0
    def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
                 attn_resolutions, dropout=0.0, resamp_with_conv=True,
                 in_channels, c_channels,
                 resolution, z_channels, use_timestep=False, **ignore_kwargs):
        """Build a U-Net-shaped network.

        A downsampling path runs on a ``c_channels`` input. A 1x1 projection
        takes a ``z_channels`` input to the bottleneck width. The middle block
        takes twice the bottleneck width, and an upsampling path with skip
        connections follows. All arguments are keyword-only.

        Args:
            ch: base channel width. Level ``i`` works at ``ch * ch_mult[i]`` channels.
            out_ch: output channels of ``conv_out``.
            ch_mult: per-level channel multipliers. ``len(ch_mult)`` is the
                number of levels.
            num_res_blocks: ResnetBlocks per down level. Up levels get one more.
            attn_resolutions: an AttnBlock follows each ResnetBlock on levels
                whose tracked resolution (``curr_res``) is in this collection.
            dropout: forwarded to every ResnetBlock.
            resamp_with_conv: forwarded to Downsample / Upsample.
            in_channels: accepted but not used in this method. ``conv_in``
                reads ``c_channels`` instead.
            c_channels: input channels of ``conv_in`` (presumably the
                conditioning input; confirm).
            resolution: starting value of ``curr_res``.
            z_channels: input channels of the 1x1 ``z_in`` projection.
            use_timestep: when true, builds ``self.temb.dense`` (two Linear layers).
            **ignore_kwargs: accepted and not used.
        """
        super().__init__()
        self.ch = ch
        # Timestep-embedding width passed to every ResnetBlock, even when
        # use_timestep is False.
        self.temb_ch = self.ch*4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList([
                torch.nn.Linear(self.ch,
                                self.temb_ch),
                torch.nn.Linear(self.temb_ch,
                                self.temb_ch),
            ])

        # downsampling
        self.conv_in = torch.nn.Conv2d(c_channels,
                                       self.ch,
                                       kernel_size=3,
                                       stride=1,
                                       padding=1)

        curr_res = resolution
        # Prepending 1 makes in_ch_mult[i] the previous level's multiplier.
        in_ch_mult = (1,)+tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch*in_ch_mult[i_level]
            block_out = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            down = nn.Module()
            down.block = block
            down.attn = attn
            # The last level has no Downsample.
            if i_level != self.num_resolutions-1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # 1x1 projection of the z input to the bottleneck width (block_in).
        self.z_in = torch.nn.Conv2d(z_channels,
                                    block_in,
                                    kernel_size=1,
                                    stride=1,
                                    padding=0)
        # middle
        self.mid = nn.Module()
        # Takes 2*block_in channels, presumably the down-path features
        # concatenated with z_in's output (confirm in forward).
        self.mid.block_1 = ResnetBlock(in_channels=2*block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)
        self.mid.attn_1 = AttnBlock(block_in)
        self.mid.block_2 = ResnetBlock(in_channels=block_in,
                                       out_channels=block_in,
                                       temb_channels=self.temb_ch,
                                       dropout=dropout)

        # upsampling
        # Built from the deepest level upward. Each block's input is the running
        # width plus a skip connection of skip_in channels. The last block of a
        # level uses the previous level's width (in_ch_mult[i_level]) as skip_in.
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch*ch_mult[i_level]
            skip_in = ch*ch_mult[i_level]
            for i_block in range(self.num_res_blocks+1):
                if i_block == self.num_res_blocks:
                    skip_in = ch*in_ch_mult[i_level]
                block.append(ResnetBlock(in_channels=block_in+skip_in,
                                         out_channels=block_out,
                                         temb_channels=self.temb_ch,
                                         dropout=dropout))
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(AttnBlock(block_in))
            up = nn.Module()
            up.block = block
            up.attn = attn
            # Level 0 (the highest resolution) has no Upsample.
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up) # prepend so self.up[i] is the same level as self.down[i]

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in,
                                        out_ch,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)
예제 #28
0
def test_proper_refcount():
    """A LightningModule should hold no extra references compared to an nn.Module.

    Each object is bound to exactly one local name, so the two
    ``sys.getrefcount`` values should match. A difference would indicate
    that something is keeping an extra reference to the LightningModule.
    """
    plain = nn.Module()
    lightning = LightningModule()

    plain_refs = sys.getrefcount(plain)
    lightning_refs = sys.getrefcount(lightning)
    assert plain_refs == lightning_refs
    def __init__(
        self,
        size,
        style_dim,
        n_mlp,
        channel_multiplier=2,
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    ):
        """Build a StyleGAN2-like generator: mapping MLP plus a synthesis network.

        Args:
            size: output resolution. ``int(math.log(size, 2))`` sets the number
                of stages, so a non-power-of-two is truncated down. Sizes above
                1024 make ``self.channels[2 ** i]`` raise ``KeyError``.
            style_dim: width of the style vector and of every mapping layer.
            n_mlp: number of EqualLinear layers in the mapping network after PixelNorm.
            channel_multiplier: scales the channel counts of the 64..1024 stages.
            blur_kernel: forwarded to every StyledConv. NOTE(review): this is a
                mutable list default shared across calls.
            lr_mlp: ``lr_mul`` passed to each mapping-network EqualLinear.
        """
        super().__init__()

        self.size = size

        self.style_dim = style_dim

        # Mapping network: PixelNorm followed by n_mlp EqualLinear layers with
        # the 'fused_lrelu' activation.
        layers = [PixelNorm()]

        for i in range(n_mlp):
            layers.append(
                EqualLinear(
                    style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
                )
            )

        self.style = nn.Sequential(*layers)

        # Channel count per feature-map resolution.
        self.channels = {
            4: 512,
            8: 512,
            16: 512,
            32: 512,
            64: 256 * channel_multiplier,
            128: 128 * channel_multiplier,
            256: 64 * channel_multiplier,
            512: 32 * channel_multiplier,
            1024: 16 * channel_multiplier,
        }

        self.input = ConstantInput(self.channels[4])
        self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
        self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

        self.log_size = int(math.log(size, 2))
        # conv1 plus two StyledConvs per stage for stages 8 .. 2**log_size.
        self.num_layers = (self.log_size - 2) * 2 + 1

        self.convs = nn.ModuleList()
        # NOTE(review): never populated in this method.
        self.upsamples = nn.ModuleList()
        self.to_rgbs = nn.ModuleList()
        # Plain container that only holds the noise buffers registered below.
        self.noises = nn.Module()

        in_channel = self.channels[4]

        # One fixed random noise buffer per styled conv. Layer 0 is 4x4, then
        # two layers each at 8, 16, ..., 2**log_size, since res = (idx + 5) // 2.
        for layer_idx in range(self.num_layers):
            res = (layer_idx + 5) // 2
            shape = [1, 1, 2 ** res, 2 ** res]
            self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))

        # Each stage: an upsampling StyledConv, a same-channel StyledConv, and a
        # ToRGB (default upsample argument).
        for i in range(3, self.log_size + 1):
            out_channel = self.channels[2 ** i]

            self.convs.append(
                StyledConv(
                    in_channel,
                    out_channel,
                    3,
                    style_dim,
                    upsample=True,
                    blur_kernel=blur_kernel,
                )
            )

            self.convs.append(
                StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )

            self.to_rgbs.append(ToRGB(out_channel, style_dim))

            in_channel = out_channel

        # Presumably the number of style vectors consumed in forward (confirm).
        self.n_latent = self.log_size * 2 - 2

        # Initialised to 0. Its meaning is not visible here; presumably a
        # progressive-training state flag (confirm).
        self.transition = 0
예제 #30
0
 def __init__(self, base_model):
     super(PTModelWrapper, self).__init__(base_model)
     self.H = nn.Module()
     self.H.register_parameter('threshold', nn.Parameter(torch.Tensor(
         [0.5])))  # initialize to prob=0.5 for faster convergence.