# Code example #1
0
# File: VaDE.py  Project: baohq1595/VaDE
# Encoder log-variance head; `h` and `z_mean` are built earlier in the file.
z_log_var = Dense(latent_dim)(h)
# Draw the latent code from (z_mean, z_log_var) via the `sampling` helper
# defined elsewhere (presumably the VAE reparameterization trick - confirm).
z = Lambda(sampling, output_shape=(latent_dim, ))([z_mean, z_log_var])

# Decoder: walks `intermediate_dim` in reverse order, so it must contain at
# least three layer widths (indices -1, -2, -3 are read).
h_decoded = Dense(intermediate_dim[-1], activation='relu')(z)
h_decoded = Dense(intermediate_dim[-2], activation='relu')(h_decoded)
h_decoded = Dense(intermediate_dim[-3], activation='relu')(h_decoded)
# The output activation is taken from `datatype` (used directly as a Keras
# activation argument).
x_decoded_mean = Dense(original_dim, activation=datatype)(h_decoded)

#========================
# Auxiliary models. Gamma has one entry per centroid (n_centroid); presumably
# the GMM cluster responsibilities computed by `get_gamma` - confirm there.
Gamma = Lambda(get_gamma, output_shape=(n_centroid, ))(z)
sample_output = Model(x, z_mean)  # x -> latent mean
gamma_output = Model(x, Gamma)    # x -> Gamma
#===========================================
# Full VaDE model: input x -> reconstruction x_decoded_mean.
vade = Model(x, x_decoded_mean)
# Truthiness test instead of `== True` (PEP 8); assumes `ispretrain` is a
# boolean flag.
if ispretrain:
    vade = load_pretrain_weights(vade, dataset)
adam_nn = Adam(lr=lr_nn, epsilon=1e-4)
adam_gmm = Adam(lr=lr_gmm, epsilon=1e-4)
# NOTE(review): this assigns a private Keras attribute so that the GMM
# parameters theta_p, u_p and lambda_p are included in the model's trainable
# weights. Whether the Dense-layer weights are still trained as well depends
# on how the installed Keras version builds `Model.trainable_weights` -
# confirm. `adam_nn` is not used anywhere in this chunk. Only `adam_gmm` is
# passed to compile(), so every trained weight uses the GMM learning rate
# (lr_gmm).
vade._trainable_weights = [theta_p, u_p, lambda_p]
vade.compile(optimizer=adam_gmm, loss=vae_loss)
epoch_begin = EpochBegin()
#-------------------------------------------------------

# Train as an autoencoder: X is both the input and the reconstruction target.
# `nb_epoch` is the Keras 1 argument name (Keras 2 calls it `epochs`). It is
# kept as-is for compatibility with whichever Keras version this project
# targets.
vade.fit(X,
         X,
         shuffle=True,
         nb_epoch=epoch,
         batch_size=batch_size,
         callbacks=[epoch_begin])