Ejemplo n.º 1
0
            logits=D_logit_fake, labels=tf.ones_like(D_logit_fake)))

    D_solver = tf.train.GradientDescentOptimizer(learning_rate=lr).minimize(
        D_loss, var_list=discriminator_vars)
    G_solver = tf.train.AdamOptimizer().minimize(G_loss,
                                                 var_list=generator_vars)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # plot the output from the same seed
    vis_z = sample_Z(batch_size, seq_length, latent_dim, use_time=use_time)
    X_mb_vis, Y_mb_vis = get_batch(train_seqs, train_targets, batch_size, 0)
    vis_sample = sess.run(G_sample, feed_dict={Z: vis_z, CG: Y_mb_vis})
    plotting.vis_eICU_patients_downsampled(vis_sample,
                                           seq_length,
                                           identifier=identifier,
                                           idx=0)

    # visualise some real samples
    vis_real = np.float32(
        vali_seqs[np.random.choice(len(vali_seqs), size=batch_size), :, :])
    plotting.vis_eICU_patients_downsampled(vis_real,
                                           seq_length,
                                           identifier=identifier + '_real',
                                           idx=0)

    trace = open('./experiments/traces/' + identifier + '.trace.txt', 'w')
    trace.write('epoch D_loss G_loss time\n')
    print('epoch\tD_loss\tG_loss\ttime\n')
    t0 = time.time()
Ejemplo n.º 2
0
        samps = vis_real
    if multivariate_mnist:
        plotting.save_mnist_plot_sample(samps.reshape(-1, seq_length**2, 1),
                                        0,
                                        identifier + '_real',
                                        n_samples=6,
                                        labels=labs)
    else:
        plotting.save_mnist_plot_sample(samps,
                                        0,
                                        identifier + '_real',
                                        n_samples=6,
                                        labels=labs)
elif 'eICU' in data:
    plotting.vis_eICU_patients_downsampled(vis_real,
                                           resample_rate_in_min,
                                           identifier=identifier + '_real',
                                           idx=0)
else:
    plotting.save_plot_sample(vis_real,
                              0,
                              identifier + '_real',
                              n_samples=6,
                              num_epochs=num_epochs)

# Create (or truncate) the per-run trace log under ./experiments/traces/ and
# emit its whitespace-separated header row. The handle is deliberately not
# closed here -- presumably later code appends per-epoch rows to it (not
# visible in this chunk; confirm before wrapping it in a `with`).
trace = open(''.join(['./experiments/traces/', identifier, '.trace.txt']),
             mode='w')
trace.write('epoch time D_loss G_loss mmd2 that ll real_ll\n')

# --- train --- #
train_vars = [
    'batch_size', 'D_rounds', 'G_rounds', 'use_time', 'seq_length',
    'latent_dim', 'num_generated_features', 'cond_dim', 'max_val', 'WGAN_clip',