コード例 #1
0
class Ladder(torch.nn.Module):
    """Ladder network built from an encoder stack and a decoder stack.

    This module is mostly a facade: it owns a ``StackedEncoders`` instance
    (``self.se``) and a ``StackedDecoders`` instance (``self.de``), and most
    methods simply delegate to one of them.

    Args:
        encoder_sizes: Sequence of encoder layer widths. Its last entry is
            used as the input width of the decoder stack.
        decoder_sizes: Sequence of decoder layer widths. Its last entry is
            used as the encoder input width and as the width of
            ``bn_image``.
        encoder_activations: Forwarded unchanged to ``StackedEncoders``.
        encoder_train_bn_scaling: Forwarded unchanged to ``StackedEncoders``.
        noise_std: Forwarded unchanged to ``StackedEncoders``.
        use_cuda: Stored as ``self.use_cuda`` and forwarded to both stacks.
    """

    def __init__(self, encoder_sizes, decoder_sizes, encoder_activations,
                 encoder_train_bn_scaling, noise_std, use_cuda):
        super().__init__()
        self.use_cuda = use_cuda
        # The decoder stack is fed the width of the top encoder layer.
        decoder_in = encoder_sizes[-1]
        # The decoder stack's final width doubles as the encoder input width
        # (presumably because the decoder reconstructs the input — confirm).
        encoder_in = decoder_sizes[-1]
        # NOTE: keep the se -> de -> bn_image order; it fixes submodule
        # registration order (state_dict / parameters() ordering).
        self.se = StackedEncoders(encoder_in, encoder_sizes, encoder_activations,
                                  encoder_train_bn_scaling, noise_std, use_cuda)
        self.de = StackedDecoders(decoder_in, decoder_sizes, encoder_in, use_cuda)
        # Non-learnable (affine=False) batch-norm over encoder_in features;
        # presumably used elsewhere to normalise the raw input — confirm.
        self.bn_image = torch.nn.BatchNorm1d(encoder_in, affine=False)

    def forward_encoders_clean(self, data):
        """Run the encoder stack's clean pass (``se.forward_clean``) on *data*."""
        return self.se.forward_clean(data)

    def forward_encoders_noise(self, data):
        """Run the encoder stack's noisy pass (``se.forward_noise``) on *data*."""
        return self.se.forward_noise(data)

    def forward_decoders(self, tilde_z_layers, encoder_output, tilde_z_bottom):
        """Run the decoder stack and return its output.

        The stack is invoked through ``nn.Module.__call__`` rather than
        ``.forward()`` so that any hooks registered on it are executed;
        calling ``.forward()`` directly would silently skip them.
        """
        return self.de(tilde_z_layers, encoder_output, tilde_z_bottom)

    def get_encoders_tilde_z(self, reverse=True):
        """Return the encoder stack's tilde-z values.

        *reverse* is forwarded as-is; its ordering semantics are defined by
        ``StackedEncoders.get_encoders_tilde_z`` (verify there).
        """
        return self.se.get_encoders_tilde_z(reverse)

    def get_encoders_z_pre(self, reverse=True):
        """Return the encoder stack's z_pre values (*reverse* forwarded as-is)."""
        return self.se.get_encoders_z_pre(reverse)

    def get_encoder_tilde_z_bottom(self):
        """Return a clone of the encoder stack's ``buffer_tilde_z_bottom``.

        The clone means in-place edits by the caller do not touch the
        encoder's stored buffer.
        """
        return self.se.buffer_tilde_z_bottom.clone()

    def get_encoders_z(self, reverse=True):
        """Return the encoder stack's z values (*reverse* forwarded as-is)."""
        return self.se.get_encoders_z(reverse)

    def decoder_bn_hat_z_layers(self, hat_z_layers, z_pre_layers):
        """Delegate to ``StackedDecoders.bn_hat_z_layers`` and return its result."""
        return self.de.bn_hat_z_layers(hat_z_layers, z_pre_layers)
コード例 #2
0
 def __init__(self, encoder_sizes, decoder_sizes, encoder_activations,
              encoder_train_bn_scaling, noise_std, use_cuda):
     """Build the encoder stack, the decoder stack and the input batch-norm.

     Widths are wired from the ends of the two size lists: the last entry
     of *encoder_sizes* feeds the decoder stack, and the last entry of
     *decoder_sizes* is used both as the encoder input width and as the
     width of ``bn_image``. *use_cuda* is forwarded to both stacks and
     stored on the instance.
     """
     super(Ladder, self).__init__()

     bottleneck_width = encoder_sizes[-1]
     input_width = decoder_sizes[-1]

     # NOTE: construction order se -> de -> bn_image is deliberate; it fixes
     # submodule registration order (and, presumably, the order in which any
     # random weight initialisation consumes the RNG).
     encoder_args = (input_width, encoder_sizes, encoder_activations,
                     encoder_train_bn_scaling, noise_std, use_cuda)
     self.se = StackedEncoders(*encoder_args)
     self.de = StackedDecoders(bottleneck_width, decoder_sizes,
                               input_width, use_cuda)
     # affine=False: normalisation only, no learnable scale/shift.
     self.bn_image = torch.nn.BatchNorm1d(num_features=input_width,
                                          affine=False)
     self.use_cuda = use_cuda
コード例 #3
0
 def __init__(self, encoder_in, encoder_sizes, decoder_in, decoder_sizes,
              image_size, encoder_activations, encoder_train_bn_scaling,
              encoder_bias, noise_std):
     """Build the encoder stack, the decoder stack and the image batch-norm.

     Every width is passed in explicitly. *encoder_in*, *encoder_sizes*,
     *encoder_activations*, *encoder_train_bn_scaling*, *encoder_bias* and
     *noise_std* go to ``StackedEncoders``. *decoder_in*, *decoder_sizes*
     and *image_size* go to ``StackedDecoders``. *image_size* is also the
     feature count of the non-affine ``bn_image`` batch-norm.
     """
     super(Ladder, self).__init__()

     # Each stack is assigned right after it is built, so construction and
     # registration order stay se -> de -> bn_image.
     encoder_stack = StackedEncoders(
         encoder_in,
         encoder_sizes,
         encoder_activations,
         encoder_train_bn_scaling,
         encoder_bias,
         noise_std,
     )
     self.se = encoder_stack

     decoder_stack = StackedDecoders(decoder_in, decoder_sizes, image_size)
     self.de = decoder_stack

     self.bn_image = torch.nn.BatchNorm1d(num_features=image_size, affine=False)