def binary_crossentropy(y_true, y_pred):
    """Return the binary cross-entropy of the inputs, averaged over the last axis.

    Args:
        y_true: Target tensor, forwarded unchanged to ``K.binary_crossentropy``.
        y_pred: Predicted tensor, forwarded unchanged to
            ``K.binary_crossentropy``.

    Returns:
        The result of ``K.binary_crossentropy(y_true, y_pred)`` reduced with
        ``K.mean`` along ``axis=-1``.
    """
    # NOTE(review): ``K`` is presumably the Keras backend module imported
    # elsewhere in this file -- confirm.
    xent = K.binary_crossentropy(y_true, y_pred)
    return K.mean(xent, axis=-1)
# --- Command-line options -------------------------------------------------
# ``parser`` and the first ``help_`` string are created earlier in the script.
parser.add_argument("-w", "--weights", help=help_)
help_ = "Use mse loss instead of binary cross entropy (default)"
parser.add_argument("-m", "--mse", help=help_, action='store_true')
args = parser.parse_args()

# --- Model ----------------------------------------------------------------
# get_model returns the eight objects unpacked here; the tensors among them
# are reused below to assemble the loss.
(
    vae,
    encoder,
    decoder,
    inputs,
    z_mean,
    z_log_var,
    latent_inputs,
    outputs,
) = get_model(input_shape, intermediate_dim, latent_dim)
models = (encoder, decoder)
data = (x_test, y_test)

# --- Loss -----------------------------------------------------------------
# VAE loss = (mse_loss or xent_loss) + kl_loss
# Flattening of the tensors is currently disabled:
# inputs = K.flatten(inputs)
# outputs = K.flatten(outputs)
# NOTE(review): with flattening disabled, the reconstruction term keeps every
# axis except the last, while the KL term is reduced over its last axis only;
# confirm the two shapes are compatible for the addition below.
reconstruction_loss = (
    mse(inputs, outputs) if args.mse else binary_crossentropy(inputs, outputs)
)
# Scale the reconstruction term by image_size squared.
reconstruction_loss *= image_size * image_size

# kl_loss = -0.5 * sum(1 + z_log_var - z_mean^2 - exp(z_log_var)) over the
# last axis (presumably the closed-form KL divergence of a diagonal Gaussian
# from a standard normal -- confirm against get_model).
kl_loss = K.sum(
    1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
    axis=-1,
)
kl_loss = kl_loss * -0.5

# The loss is attached with add_loss, so compile() receives only an optimizer.
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')
vae.summary()
# plot_model(vae,
#            to_file='vae_mlp.png',
#            show_shapes=True)

# --- Data -----------------------------------------------------------------
# One generator with data_gen's defaults, one requested with 'test'.
train_data_gen = data_gen()
test_data_gen = data_gen('test')