def cost(X, Z_prior_mean, Z_prior_logvar,
         Z_mean, Z_logvar,
         X_mean, X_logvar,
         lengths):
    """Return the masked, per-step-averaged negated VAE cost.

    A boolean mask of shape ``(X.shape[0], len(lengths))`` is built with
    ``mask[t, b] = t < lengths[b]``.  The KL term and the Gaussian NLL term
    are kept where the mask is true and replaced by 0 elsewhere; their sum
    is negated and divided by the number of true mask entries.

    Args:
        X: observed data; only ``X.shape[0]`` is used for the mask.
            (assumes axis 0 is the time axis -- TODO confirm with caller)
        Z_prior_mean, Z_prior_logvar: parameters passed to
            ``vae.kl_divergence`` as distribution 1.
        Z_mean, Z_logvar: parameters passed to ``vae.kl_divergence`` as
            distribution 2.
        X_mean, X_logvar: parameters passed to ``vae.gaussian_nll``.
        lengths: vector compared against the step index, one entry per
            column of the mask (presumably the valid length of each
            sequence in the batch).

    Returns:
        Symbolic scalar ``-sum(masked KL + masked NLL) / sum(mask)``.
    """
    # Column of step indices vs. row of lengths -> 2-D boolean mask.
    mask = (T.arange(X.shape[0]).dimshuffle(0, 'x')
            < lengths.dimshuffle('x', 0))

    # Select with T.switch rather than multiplying by the mask: a NaN/inf
    # computed at a padded position would survive `0 * value` and poison
    # the whole sum, whereas switch discards it outright.
    # NOTE(review): the prior is passed as distribution 1 here; confirm this
    # matches the argument convention of vae.kl_divergence.
    encoding_cost = T.switch(
        mask,
        vae.kl_divergence(
            mean_1=Z_prior_mean, logvar_1=Z_prior_logvar,
            mean_2=Z_mean, logvar_2=Z_logvar,
        ),
        0,
    )
    reconstruction_cost = T.switch(
        mask,
        vae.gaussian_nll(X, X_mean, X_logvar),
        0,
    )

    # Normalise by the number of valid (unmasked) positions.
    return -T.sum(encoding_cost + reconstruction_cost) / T.sum(mask)
def cost(X, Z_prior_mean, Z_prior_std, Z_mean, Z_std, X_mean, X_std, lengths):
    """Compute the negated, length-masked VAE cost averaged over valid steps.

    Positions ``(t, b)`` with ``t >= lengths[b]`` contribute 0 to both the
    KL term and the Gaussian NLL term.  The remaining entries are summed,
    negated, and divided by the count of valid positions.

    Args:
        X: observed data; ``X.shape[0]`` sets the number of step indices.
            (assumes axis 0 is time -- TODO confirm with caller)
        Z_prior_mean, Z_prior_std: passed to ``vae.kl_divergence`` as
            distribution 2.
        Z_mean, Z_std: passed to ``vae.kl_divergence`` as distribution 1.
        X_mean, X_std: passed to ``vae.gaussian_nll`` together with ``X``.
        lengths: vector with one entry per mask column (presumably the
            valid length of each sequence).

    Returns:
        Symbolic scalar ``-sum(masked KL + masked NLL) / sum(mask)``.
    """
    # (steps, 1) indices compared with (1, n) lengths broadcast to a 2-D mask.
    step_index = T.arange(X.shape[0]).dimshuffle(0, 'x')
    valid_length = lengths.dimshuffle('x', 0)
    mask = step_index < valid_length

    kl_term = vae.kl_divergence(
        mean_1=Z_mean,
        std_1=Z_std,
        mean_2=Z_prior_mean,
        std_2=Z_prior_std,
    )
    nll_term = vae.gaussian_nll(X, X_mean, X_std)

    # Zero out everything that falls in the padded region.
    encoding_cost = T.switch(mask, kl_term, 0)
    reconstruction_cost = T.switch(mask, nll_term, 0)

    summed = T.sum(encoding_cost + reconstruction_cost)
    valid_count = T.sum(mask)
    return -summed / valid_count
def cost(X, Z_prior_mean, Z_prior_std, Z_mean, Z_std, X_mean, X_std, lengths):
    """Negated sum of masked KL and reconstruction terms per valid step.

    The mask is true at ``(t, b)`` when ``t < lengths[b]``; both cost terms
    are replaced by 0 wherever the mask is false before being summed.

    Args:
        X: observed data; its first-axis size determines the mask's rows.
            (assumes the first axis indexes time steps -- TODO confirm)
        Z_prior_mean, Z_prior_std: second distribution given to
            ``vae.kl_divergence``.
        Z_mean, Z_std: first distribution given to ``vae.kl_divergence``.
        X_mean, X_std: parameters given to ``vae.gaussian_nll`` with ``X``.
        lengths: vector supplying one threshold per mask column.

    Returns:
        Symbolic scalar: the negated total of the masked terms divided by
        the number of true mask entries.
    """
    rows = T.arange(X.shape[0]).dimshuffle(0, 'x')
    mask = rows < lengths.dimshuffle('x', 0)

    def keep_valid(term):
        # Pass the term through where the mask is true, 0 elsewhere.
        return T.switch(mask, term, 0)

    encoding_cost = keep_valid(
        vae.kl_divergence(mean_1=Z_mean, std_1=Z_std,
                          mean_2=Z_prior_mean, std_2=Z_prior_std)
    )
    reconstruction_cost = keep_valid(vae.gaussian_nll(X, X_mean, X_std))

    return -T.sum(encoding_cost + reconstruction_cost) / T.sum(mask)