예제 #1
0
    def __init__(
        self,
        label,
        image_size,
        channel_num,
        kernel_num,
        z_size,
        device,
    ):
        """Build the encoder, decoder and latent projection layers.

        Args:
            label: Stored unchanged as ``self.label``.
            image_size: Image size; the encoded feature size is computed as
                ``image_size // 8`` (presumably the side of a square image
                halved by each of the three encoder stages -- confirm
                against ``_conv``).
            channel_num: First argument of the first encoder stage and
                second argument of the last decoder stage.
            kernel_num: Widest channel setting; the intermediate stages use
                ``kernel_num // 4`` and ``kernel_num // 2``.
            z_size: Size of the latent code produced by ``q_layer`` and
                consumed by ``project``.
            device: Stored unchanged as ``self.device``; not used further
                in this constructor.
        """
        super().__init__()

        # Plain configuration values kept on the instance.
        self.model_name = "ae_vine"
        self.label = label
        self.image_size = image_size
        self.channel_num = channel_num
        self.kernel_num = kernel_num
        self.z_size = z_size
        self.device = device
        # Placeholder; nothing in this constructor fills it in.
        self.vine = None

        # Channel widths shared by the mirrored encoder/decoder stacks.
        quarter = kernel_num // 4
        half = kernel_num // 2

        # Encoder: channel_num -> quarter -> half -> kernel_num.
        encoder_stages = [
            _conv(channel_num, quarter),
            _conv(quarter, half),
            _conv(half, kernel_num),
        ]
        self.encoder = nn.Sequential(*encoder_stages)

        # Size of the encoded feature map and its flattened volume.
        side = image_size // 8
        self.feature_size = side
        self.feature_volume = kernel_num * (side**2)

        # Decoder: kernel_num -> half -> quarter -> channel_num, then sigmoid.
        decoder_stages = [
            _deconv(kernel_num, half),
            _deconv(half, quarter),
            _deconv(quarter, channel_num),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_stages)

        # Linear maps between the latent code and the flattened features.
        self.project = _linear(z_size, self.feature_volume, relu=False)
        self.q_layer = _linear(self.feature_volume, z_size, relu=False)
예제 #2
0
    def __init__(self, label, image_size, channel_num, kernel_num, z_size, device):
        """Build the encoder, the latent heads, the projection and the decoder.

        Args:
            label: Stored unchanged as ``self.label``.
            image_size: Image size; the encoded feature size is computed as
                ``image_size // 8`` (presumably the side of a square image
                halved by each of the three encoder stages -- confirm
                against ``_conv``).
            channel_num: First argument of the first encoder stage and
                second argument of the last decoder stage.
            kernel_num: Widest channel setting; the intermediate stages use
                ``kernel_num // 4`` and ``kernel_num // 2``.
            z_size: Size of the latent code.
            device: Stored unchanged as ``self.device``; not used further
                in this constructor.
        """
        # configurations
        super().__init__()
        self.model_name = "cvae"
        self.label = label
        self.image_size = image_size
        self.channel_num = channel_num
        self.kernel_num = kernel_num
        self.z_size = z_size
        self.device = device

        # encoder: channel_num -> kernel_num // 4 -> kernel_num // 2 -> kernel_num
        self.encoder = nn.Sequential(
            _conv(channel_num, kernel_num // 4),
            _conv(kernel_num // 4, kernel_num // 2),
            _conv(kernel_num // 2, kernel_num),
        )

        # encoded feature's size and volume
        self.feature_size = image_size // 8
        self.feature_volume = kernel_num * (self.feature_size ** 2)

        # q: heads mapping the flattened features to the latent statistics
        self.q_mean = _linear(self.feature_volume, z_size, relu=False)
        self.q_logvar = _linear(self.feature_volume, z_size, relu=False)
        # One output per unordered pair of latent dimensions. z * (z - 1) is
        # always even, so floor division is exact and avoids the float
        # round-trip of int(z * (z - 1) / 2).
        n_pairs = z_size * (z_size - 1) // 2
        # NOTE(review): the name suggests atanh-transformed correlations --
        # confirm against the code that consumes this head.
        self.q_atanhcor = _linear(self.feature_volume, n_pairs, relu=False)

        # projection: latent code back to the flattened feature volume
        self.project = _linear(z_size, self.feature_volume, relu=False)

        # decoder: mirror of the encoder, followed by a sigmoid
        self.decoder = nn.Sequential(
            _deconv(kernel_num, kernel_num // 2),
            _deconv(kernel_num // 2, kernel_num // 4),
            _deconv(kernel_num // 4, channel_num),
            nn.Sigmoid(),
        )