Exemplo n.º 1
0
    def _initialize(self):
        """Build the model's sub-modules from ``self.init_args``.

        Constructs, in order: the preprocessing stage, the encoder, the
        bottleneck together with its matching objective, and the WaveNet
        decoder. It then exposes the decoder's ``vc`` as ``self.vc`` and
        calls ``self.decoder.post_init()``.

        ``self.init_args`` must provide the keys ``pre_params``,
        ``enc_params``, ``bn_params``, ``dec_params``, ``n_mel_chan`` and
        ``training``.

        ``bn_params['type']`` selects the bottleneck. For the ``'vqvae'``
        and ``'vqvae-ema'`` types, every other ``bn_params`` entry is
        forwarded to the bottleneck constructor as a keyword argument.

        Raises:
            InvalidArgument: if ``bn_params['type']`` is not one of
                ``'vqvae'``, ``'vqvae-ema'``, ``'vae'`` or ``'ae'``.
        """
        super(AutoEncoder, self).__init__()
        pre_params = self.init_args['pre_params']
        enc_params = self.init_args['enc_params']
        bn_params = self.init_args['bn_params']
        dec_params = self.init_args['dec_params']
        n_mel_chan = self.init_args['n_mel_chan']
        training = self.init_args['training']

        # the "preprocessing"; n_quant is taken from the decoder config
        self.preprocess = PreProcess(pre_params, n_quant=dec_params['n_quant'])

        # The encoder has no parent; the decoder below attaches to
        # self.encoder.vc['end'].
        self.encoder = enc.Encoder(n_in=n_mel_chan,
                                   parent_vc=None,
                                   **enc_params)

        bn_type = bn_params['type']
        # All bottleneck settings except the 'type' selector.
        bn_extra = {k: v for k, v in bn_params.items() if k != 'type'}

        # In each case, the objective function's 'forward' method takes the
        # same arguments.
        if bn_type == 'vqvae':
            self.bottleneck = vq_bn.VQ(**bn_extra, n_in=enc_params['n_out'])
            self.objective = vq_bn.VQLoss(self.bottleneck)

        elif bn_type == 'vqvae-ema':
            self.bottleneck = vqema_bn.VQEMA(**bn_extra,
                                             n_in=enc_params['n_out'],
                                             training=training)
            self.objective = vqema_bn.VQEMALoss(self.bottleneck)

        elif bn_type == 'vae':
            # mu and sigma members
            self.bottleneck = vae_bn.VAE(n_in=enc_params['n_out'],
                                         n_out=bn_params['n_out'])
            self.objective = vae_bn.SGVBLoss(self.bottleneck,
                                             free_nats=bn_params['free_nats'])

        elif bn_type == 'ae':
            self.bottleneck = ae_bn.AE(n_out=bn_params['n_out'],
                                       n_in=enc_params['n_out'])
            # NOTE(review): the meaning of 0.001 is defined by ae_bn.AELoss;
            # consider moving it into bn_params once confirmed.
            self.objective = ae_bn.AELoss(self.bottleneck, 0.001)

        else:
            # Name every accepted bn_type (including 'vqvae-ema') and echo
            # the rejected value so config mistakes are easy to spot.
            raise InvalidArgument(
                'bn_type must be one of "ae", "vae", "vqvae", or "vqvae-ema" '
                '(got {!r})'.format(bn_type))

        self.bn_type = bn_type
        # The decoder's n_lc_in is the bottleneck's n_out; its vc chain
        # continues from the end of the encoder's.
        self.decoder = dec.WaveNet(**dec_params,
                                   parent_vc=self.encoder.vc['end'],
                                   n_lc_in=bn_params['n_out'])
        self.vc = self.decoder.vc
        self.decoder.post_init()
Exemplo n.º 2
0
    def _initialize(self):
        """Create the preprocessing, encoder, bottleneck/objective and decoder.

        ``self.args`` is unpacked as the 5-tuple ``(pre_params, enc_params,
        bn_params, dec_params, sam_per_slice)``; the fifth element is not
        used here.

        ``bn_params['type']`` chooses the bottleneck (``'vqvae'``,
        ``'vae'`` or ``'ae'``). The remaining ``bn_params`` entries are
        forwarded to the bottleneck constructor as keyword arguments.

        The decoder's ``rf`` is exposed as ``self.rf``.

        Raises:
            InvalidArgument: for any other ``bn_params['type']``.
        """
        super(AutoEncoder, self).__init__()
        (pre_params, enc_params, bn_params,
         dec_params, _sam_per_slice) = self.args

        # Preprocessing stage; n_quant comes from the decoder config.
        self.preprocess = PreProcess(pre_params, n_quant=dec_params['n_quant'])

        # Encoder input width is the preprocessor's mfcc.n_out, and its rf
        # chains onto the preprocessing stage's rf.
        mfcc_width = self.preprocess.mfcc.n_out
        self.encoder = enc.Encoder(n_in=mfcc_width,
                                   parent_rf=self.preprocess.rf,
                                   **enc_params)

        bn_type = bn_params['type']
        bn_kwargs = {name: value for name, value in bn_params.items()
                     if name != 'type'}

        # Reject unknown bottleneck types before building anything else.
        if bn_type not in ('vqvae', 'vae', 'ae'):
            raise InvalidArgument(
                'bn_type must be one of "ae", "vae", or "vqvae"')

        # The objectives are intended to share a common 'forward' signature.
        # NOTE(review): torch.nn.CrossEntropyLoss.forward takes
        # (input, target); confirm this matches how the other objectives
        # are invoked.
        if bn_type == 'vqvae':
            self.bottleneck = vq_bn.VQ(**bn_kwargs, n_in=enc_params['n_out'])
            self.objective = vq_bn.VQLoss(self.bottleneck)
        elif bn_type == 'vae':
            # mu and sigma members
            self.bottleneck = vae_bn.VAE(**bn_kwargs, n_in=enc_params['n_out'])
            self.objective = vae_bn.SGVBLoss(self.bottleneck)
        else:
            # Only 'ae' can reach here after the guard above.
            self.bottleneck = ae_bn.AE(**bn_kwargs, n_in=enc_params['n_out'])
            self.objective = torch.nn.CrossEntropyLoss()

        self.bn_type = bn_type
        # The decoder's n_lc_in is the bottleneck's n_out; its rf continues
        # from the encoder's.
        self.decoder = dec.WaveNet(**dec_params,
                                   parent_rf=self.encoder.rf,
                                   n_lc_in=bn_params['n_out'])
        self.rf = self.decoder.rf