def _initialize(self):
    """Build the autoencoder's submodules from ``self.init_args``.

    Reads the ``'pre_params'``, ``'enc_params'``, ``'bn_params'``,
    ``'dec_params'``, ``'n_mel_chan'`` and ``'training'`` entries of
    ``self.init_args`` and constructs, in order:

    * ``self.preprocess`` -- a ``PreProcess`` built from ``pre_params`` and
      ``dec_params['n_quant']``;
    * ``self.encoder`` -- an ``enc.Encoder`` with ``n_in=n_mel_chan``;
    * ``self.bottleneck`` and ``self.objective`` -- selected by
      ``bn_params['type']`` (``'vqvae'``, ``'vqvae-ema'``, ``'vae'`` or
      ``'ae'``);
    * ``self.decoder`` -- a ``dec.WaveNet`` chained after the encoder.

    Also sets ``self.bn_type`` and ``self.vc`` (taken from the decoder), and
    finally calls ``self.decoder.post_init()``.

    Raises:
        InvalidArgument: if ``bn_params['type']`` is not one of the four
            supported bottleneck types.
    """
    # NOTE(review): the parent constructor is invoked here rather than in
    # __init__; presumably so this method can (re)build the module on its
    # own -- confirm against the class definition.
    super(AutoEncoder, self).__init__()
    pre_params = self.init_args['pre_params']
    enc_params = self.init_args['enc_params']
    bn_params = self.init_args['bn_params']
    dec_params = self.init_args['dec_params']
    n_mel_chan = self.init_args['n_mel_chan']
    training = self.init_args['training']

    # the "preprocessing"
    self.preprocess = PreProcess(pre_params, n_quant=dec_params['n_quant'])

    self.encoder = enc.Encoder(n_in=n_mel_chan, parent_vc=None, **enc_params)

    bn_type = bn_params['type']
    # Every bottleneck option except the 'type' selector itself.
    bn_extra = {k: v for k, v in bn_params.items() if k != 'type'}

    # In each case, the objective function's 'forward' method takes the
    # same arguments.
    if bn_type == 'vqvae':
        self.bottleneck = vq_bn.VQ(**bn_extra, n_in=enc_params['n_out'])
        self.objective = vq_bn.VQLoss(self.bottleneck)
    elif bn_type == 'vqvae-ema':
        # Only this bottleneck receives the 'training' flag.
        self.bottleneck = vqema_bn.VQEMA(**bn_extra,
                                         n_in=enc_params['n_out'],
                                         training=training)
        self.objective = vqema_bn.VQEMALoss(self.bottleneck)
    elif bn_type == 'vae':
        # mu and sigma members
        self.bottleneck = vae_bn.VAE(n_in=enc_params['n_out'],
                                     n_out=bn_params['n_out'])
        self.objective = vae_bn.SGVBLoss(self.bottleneck,
                                         free_nats=bn_params['free_nats'])
    elif bn_type == 'ae':
        self.bottleneck = ae_bn.AE(n_out=bn_extra['n_out'],
                                   n_in=enc_params['n_out'])
        # NOTE(review): the meaning of the 0.001 argument to AELoss is not
        # visible here -- consider naming it / moving it into bn_params.
        self.objective = ae_bn.AELoss(self.bottleneck, 0.001)
    else:
        # The message lists every branch handled above, including
        # 'vqvae-ema', and reports the rejected value.
        raise InvalidArgument(
            'bn_type must be one of "ae", "vae", "vqvae", or "vqvae-ema"; '
            'got {!r}'.format(bn_type))

    self.bn_type = bn_type
    # The decoder's local-conditioning input width is the bottleneck's output.
    self.decoder = dec.WaveNet(**dec_params,
                               parent_vc=self.encoder.vc['end'],
                               n_lc_in=bn_params['n_out'])
    self.vc = self.decoder.vc
    self.decoder.post_init()
def _initialize(self):
    """Build the autoencoder's submodules from ``self.args``.

    ``self.args`` is unpacked as the 5-tuple
    ``(pre_params, enc_params, bn_params, dec_params, sam_per_slice)``.
    This method constructs, in order:

    * ``self.preprocess`` -- a ``PreProcess`` built from ``pre_params`` and
      ``dec_params['n_quant']``;
    * ``self.encoder`` -- an ``enc.Encoder`` whose input width is
      ``self.preprocess.mfcc.n_out`` and whose parent receptive field is
      ``self.preprocess.rf``;
    * ``self.bottleneck`` and ``self.objective`` -- selected by
      ``bn_params['type']`` (``'vqvae'``, ``'vae'`` or ``'ae'``);
    * ``self.decoder`` -- a ``dec.WaveNet`` chained after the encoder.

    Also sets ``self.bn_type`` and ``self.rf`` (taken from the decoder).

    Raises:
        InvalidArgument: if ``bn_params['type']`` is not one of the three
            supported bottleneck types.
    """
    super(AutoEncoder, self).__init__()
    # sam_per_slice is not used here; unpacking all five items still raises
    # ValueError if self.args has a different length.
    pre_params, enc_params, bn_params, dec_params, sam_per_slice = self.args

    # the "preprocessing"
    self.preprocess = PreProcess(pre_params, n_quant=dec_params['n_quant'])

    self.encoder = enc.Encoder(n_in=self.preprocess.mfcc.n_out,
                               parent_rf=self.preprocess.rf,
                               **enc_params)

    bn_type = bn_params['type']
    # Every bottleneck option except the 'type' selector itself.
    bn_extra = {k: v for k, v in bn_params.items() if k != 'type'}

    # In each case, the objective function's 'forward' method takes the
    # same arguments.
    if bn_type == 'vqvae':
        self.bottleneck = vq_bn.VQ(**bn_extra, n_in=enc_params['n_out'])
        self.objective = vq_bn.VQLoss(self.bottleneck)
    elif bn_type == 'vae':
        # mu and sigma members
        self.bottleneck = vae_bn.VAE(**bn_extra, n_in=enc_params['n_out'])
        self.objective = vae_bn.SGVBLoss(self.bottleneck)
    elif bn_type == 'ae':
        self.bottleneck = ae_bn.AE(**bn_extra, n_in=enc_params['n_out'])
        # NOTE(review): unlike the other branches, this objective is not
        # bound to self.bottleneck, and torch.nn.CrossEntropyLoss.forward
        # takes (input, target). That seems to contradict the
        # "same arguments" comment above -- confirm how callers invoke
        # self.objective for the 'ae' case.
        self.objective = torch.nn.CrossEntropyLoss()
    else:
        raise InvalidArgument(
            'bn_type must be one of "ae", "vae", or "vqvae"')

    self.bn_type = bn_type
    # The decoder's local-conditioning input width is the bottleneck's output.
    self.decoder = dec.WaveNet(**dec_params,
                               parent_rf=self.encoder.rf,
                               n_lc_in=bn_params['n_out'])
    self.rf = self.decoder.rf