Beispiel #1
0
    def __init__(self,
                 ni,
                 nf,
                 ks,
                 pooler=nn.AvgPool2d(kernel_size=2, stride=2),
                 init='He',
                 nl=nn.ReLU(),
                 equalized_lr=False,
                 blur_type=None):
        """Build a residual block that downsamples its input.

        Main path: ``conv_layer_1`` (ks x ks conv ni -> nf, then ``nl``),
        followed by ``conv_layer_2`` ([blur], ks x ks conv nf -> nf, ``pooler``).
        Skip path: ``skip_connection`` ([blur], ``pooler``, 1x1 conv ni -> nf).

        Args:
            ni: Number of input channels.
            nf: Number of output channels.
            ks: Kernel size of the two main-path convolutions; padding is
                ``(ks - 1) // 2``, which keeps spatial size for odd ``ks``.
            pooler: Downsampling module used in both the main and skip paths.
                NOTE: the default instance is created once, when the function
                is defined, and is shared by every block built with it.
            init: Weight-init scheme passed to the two main-path convolutions
                (the 1x1 skip convolution always uses 'xavier').
            nl: Non-linearity applied after the first convolution.
            equalized_lr: Forwarded to every ``Conv2dEx``.
            blur_type: If not None, a blur op from ``get_blur_op`` is placed at
                the start of ``conv_layer_2`` and of ``skip_connection``.
        """
        super(FastResBlock2dDownsample, self).__init__()

        padding = (ks - 1) // 2  # 'SAME' padding for stride-1 convs

        self.conv_layer_1 = nn.Sequential(
            Conv2dEx(ni,
                     nf,
                     ks=ks,
                     stride=1,
                     padding=padding,
                     init=init,
                     equalized_lr=equalized_lr), nl)

        main_ops = []
        skip_ops = []
        if blur_type is not None:
            # The main path blurs conv_layer_1's output (nf channels), while the
            # skip path blurs the raw block input (ni channels, as required by
            # the 1x1 ni -> nf conv below). Each needs its own blur op sized to
            # its own channel count; sharing one breaks when ni != nf.
            main_ops.append(get_blur_op(blur_type=blur_type, num_channels=nf))
            skip_ops.append(get_blur_op(blur_type=blur_type, num_channels=ni))

        main_ops += [
            Conv2dEx(nf,
                     nf,
                     ks=ks,
                     stride=1,
                     padding=padding,
                     init=init,
                     equalized_lr=equalized_lr),
            pooler,
        ]
        # Pool before the 1x1 projection so the projection runs at the
        # reduced resolution.
        skip_ops += [
            pooler,
            Conv2dEx(ni,
                     nf,
                     ks=1,
                     stride=1,
                     padding=0,
                     init='xavier',
                     equalized_lr=equalized_lr),
        ]

        # nn.Sequential names children '0', '1', ..., matching the indices the
        # original add_module calls produced.
        self.conv_layer_2 = nn.Sequential(*main_ops)
        self.skip_connection = nn.Sequential(*skip_ops)
Beispiel #2
0
    def increase_scale(self):
        """Use this to increase scale during training or for initial resolution.

        Appends two conv layers to ``gen_layers``: first an upsampling layer
        taking ``fmap_prev`` input channels (with an optional blur op), then a
        layer taking ``fmap`` input channels. It then stores a deep copy of the
        current ``torgb`` in ``prev_torgb`` and rebuilds ``torgb`` for ``fmap``
        input channels.
        """
        # update metadata: run the parent's update unless the
        # scale_inc_metadata_updated flag is set, in which case clear the flag.
        if not self.scale_inc_metadata_updated:
            # Zero-argument super() resolves relative to the class this method
            # is defined in. super(self.__class__, self) would resolve back to
            # this same method in any subclass and recurse forever.
            super().increase_scale()
        else:
            self.scale_inc_metadata_updated = False

        blur_op = (get_blur_op(blur_type=self.gen_blur_type,
                               num_channels=self.fmap)
                   if self.gen_blur_type is not None else None)

        self.gen_layers.append(
            self.get_conv_layer(ni=self.fmap_prev,
                                upsample=True,
                                blur_op=blur_op))
        self.gen_layers.append(self.get_conv_layer(ni=self.fmap))

        # Keep the old to-RGB head around, presumably for fading in the new
        # resolution; confirm against the caller.
        self.prev_torgb = copy.deepcopy(self.torgb)
        self._update_torgb(ni=self.fmap)
Beispiel #3
0
    def increase_scale(self):
        """Use this to increase scale during training or for initial resolution.

        Steps:

        * Rebuilds ``preprocess_x`` for the current resolution.
        * Stores a deep copy of ``fromrgb`` in ``prev_fromrgb`` and rebuilds
          ``fromrgb`` for ``fmap`` output channels.
        * Prepends to ``disc_blocks`` a block of two conv layers: one
          producing ``fmap`` channels, then a downsampling one producing
          ``fmap_prev`` channels (with an optional blur op).
        """
        # update metadata: run the parent's update unless the
        # scale_inc_metadata_updated flag is set, in which case clear the flag.
        if not self.scale_inc_metadata_updated:
            # Zero-argument super() resolves relative to the class this method
            # is defined in. super(self.__class__, self) would resolve back to
            # this same method in any subclass and recurse forever.
            super().increase_scale()
        else:
            self.scale_inc_metadata_updated = False

        # Reshape input to (-1, FMAP_SAMPLES + num_classes, curr_res, curr_res).
        # self.num_classes / self.curr_res are read each time the lambda runs,
        # not captured now.
        self.preprocess_x = Lambda(lambda x: x.view(
            -1, FMAP_SAMPLES + self.num_classes, self.curr_res, self.curr_res))

        self.prev_fromrgb = copy.deepcopy(self.fromrgb)
        self._update_fromrgb(nf=self.fmap)

        blur_op = (get_blur_op(blur_type=self.disc_blur_type,
                               num_channels=self.fmap)
                   if self.disc_blur_type is not None else None)

        # New highest-resolution block goes at the front of the stack.
        self.disc_blocks.insert(
            0,
            nn.Sequential(
                self.get_conv_layer(nf=self.fmap),
                self.get_conv_layer(nf=self.fmap_prev,
                                    downsample=True,
                                    blur_op=blur_op)))
Beispiel #4
0
    def __init__(self,
                 ni,
                 nf,
                 ks,
                 norm_type,
                 upsampler=None,
                 pooler=None,
                 init='He',
                 nl=nn.ReLU(),
                 res=None,
                 flip_sampling=False,
                 equalized_lr=False,
                 blur_type=None):
        """Build a pre-activation residual block with optional resampling.

        ``conv_layer_1`` = norm(ni) -> nl -> [upsampler] -> conv(ni -> nif)
        -> [blur, when upsampling].
        ``conv_layer_2`` = norm(nif) -> nl -> [blur, when pooling]
        -> conv(nif -> nf) -> [pooler].
        ``skip_connection`` is a 1x1 ni -> nf conv (wrapped with the same
        resampling/blur ops) when resampling is requested or ``ni != nf``;
        otherwise it is the identity.

        Args:
            ni: Input channels.
            nf: Output channels.
            ks: Kernel size of the two main-path convolutions.
            norm_type: Forwarded to ``NormalizeLayer``.
            upsampler: Optional upsampling module (mutually exclusive with
                ``pooler``).
            pooler: Optional downsampling module.
            init: Weight-init scheme for the two main-path convolutions.
            nl: Non-linearity, shared by both pre-activation stages.
            res: Forwarded to ``NormalizeLayer``.
            flip_sampling: Selects the alternate rule for the intermediate
                channel count ``self.nif``. NOTE(review): with either sampler
                present both rules give the same ``nif``; they differ only
                when neither is given. Confirm this is intended.
            equalized_lr: Forwarded to every ``Conv2dEx``.
            blur_type: If not None, ``get_blur_op`` output is added next to
                the resampling ops (unused when no sampler is given).
        """
        super(ResBlock2d, self).__init__()

        assert not (upsampler is not None and pooler is not None)

        has_up = upsampler is not None
        has_pool = pooler is not None
        has_blur = blur_type is not None

        same_pad = (ks - 1) // 2  # 'SAME' padding for stride 1 conv

        # Intermediate channel count between the two main-path convs.
        if flip_sampling:
            self.nif = ni if (has_pool and not has_up) else nf
        else:
            self.nif = nf if (has_up and not has_pool) else ni

        conv_a = Conv2dEx(ni,
                          self.nif,
                          ks=ks,
                          stride=1,
                          padding=same_pad,
                          init=init,
                          equalized_lr=equalized_lr)
        conv_b = Conv2dEx(self.nif,
                          nf,
                          ks=ks,
                          stride=1,
                          padding=same_pad,
                          init=init,
                          equalized_lr=equalized_lr)
        # 1x1 projection for the skip path (acts like a per-pixel FC layer).
        conv_skip = Conv2dEx(ni,
                             nf,
                             ks=1,
                             stride=1,
                             padding=0,
                             init='Xavier',
                             equalized_lr=equalized_lr)
        self.convs = (conv_a, conv_b, conv_skip)

        blur_op = None
        if has_blur:
            blur_op = get_blur_op(blur_type=blur_type, num_channels=conv_a.nf)

        pre_act_1 = [NormalizeLayer(norm_type, ni=ni, res=res), nl]
        pre_act_2 = [NormalizeLayer(norm_type, ni=conv_a.nf, res=res), nl]

        if has_up:
            # Upsample first, convolve, then blur the upsampled result.
            main_1 = [upsampler, conv_a]
            main_2 = [conv_b]
            skip_ops = [upsampler, conv_skip]
            if has_blur:
                main_1.append(blur_op)
                skip_ops.append(blur_op)
        elif has_pool:
            # Blur before the conv/pool so pooling sees a low-passed signal.
            main_1 = [conv_a]
            main_2 = [conv_b, pooler]
            skip_ops = [pooler, conv_skip]
            if has_blur:
                main_2.insert(0, blur_op)
                skip_ops.insert(0, blur_op)
        else:
            main_1, main_2, skip_ops = [conv_a], [conv_b], [conv_skip]

        self.conv_layer_1 = nn.Sequential(*pre_act_1, *main_1)
        self.conv_layer_2 = nn.Sequential(*pre_act_2, *main_2)

        needs_projection = has_up or has_pool or ni != nf
        if needs_projection:
            self.skip_connection = nn.Sequential(*skip_ops)
        else:
            self.skip_connection = Lambda(lambda x: x)