Example #1
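# (excerpt assumes the usual TF 1.x eager setup, roughly: import time;
#  import numpy as np; import tensorflow as tf;
#  import tensorflow.contrib.eager as tfe; tf.enable_eager_execution())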
        # (tail of the epoch-end StopIteration handler: restart the iterator)
        dataset_iter = tfe.Iterator(train_dataset)
        x_batch = dataset_iter.next()

    if opt_type != 'adam':
        lr.assign(
            tf.train.polynomial_decay(lr_rbm,
                                      step,
                                      decay_steps,
                                      lr_end,
                                      power=0.5))
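        # the call above follows
        #   (lr_rbm - lr_end) * (1 - min(step, decay_steps) / decay_steps) ** 0.5 + lr_end
        # i.e. a square-root decay from lr_rbm down to lr_end; assigning into
        # the `lr` variable takes effect because the optimizer was
        # (presumably) built with this same tf.Variable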

    # update the RBM parameters from the gradients
    with tf.GradientTape() as rbm_tape:
        batch_loss_rbm = loss_rbm(x_batch)
        average_loss_rbm += batch_loss_rbm
    grad_rbm = rbm_tape.gradient(batch_loss_rbm, rbm.params())
    optimizer_rbm.apply_gradients(zip(grad_rbm, rbm.params()))
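    # note: in the updates above and below, only the forward pass runs under
    # the tape; tape.gradient() is then called once outside the `with` block,
    # which is all a non-persistent GradientTape allows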

    for j in range(dec_per_rbm):
        with tf.GradientTape() as dec_tape:
            batch_loss_dec, batch_energy_neg, batch_ent_dec = loss_dec()
            average_loss_dec += batch_loss_dec
            average_energy_neg += batch_energy_neg
            average_ent_dec += batch_ent_dec
        # dec.params() + enc_h2z.params() concatenates the two variable
        # lists, so zip() pairs each gradient with its matching variable
        grad_dec = dec_tape.gradient(batch_loss_dec,
                                     dec.params() + enc_h2z.params())
        optimizer_dec.apply_gradients(
            zip(grad_dec,
                dec.params() + enc_h2z.params()))

    for j in range(enc_per_rbm):
        # (encoder update loop; its body is truncated in this excerpt)
        pass

def evaluate_ais_ll(ds, data_train):
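    """Estimate log Z of the RBM via AIS using data_train, then return it
    together with the mean AIS-based log-likelihood of `ds`, evaluated in
    batches of evaluation_batch_size."""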
    _, ais_log_z, _ = estimate_log_partition_function(data_train, rbm)
    ds = ds.batch(evaluation_batch_size)
    ds_iter = tfe.Iterator(ds)
    ais_ll = []
    for batch in ds_iter:
        ais_ll.append(
            tf.reduce_mean(likelihood_ais(batch, rbm, ais_log_z)).numpy())
    return ais_log_z.numpy(), np.mean(ais_ll)
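
# usage sketch (hypothetical names: assumes tf.data.Dataset objects
# `valid_dataset` and `test_dataset` are defined elsewhere in the script):
#
#   ais_log_z, valid_ll = evaluate_ais_ll(valid_dataset, data_train)
#   _, test_ll = evaluate_ais_ll(test_dataset, data_train)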


'''
train loop
'''
saver = tf.contrib.eager.Saver(rbm.params())
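# note: only the RBM parameters are checkpointed; the decoder/encoder
# variables trained in the inner loops are not covered by this saver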
average_loss = 0.
start = time.time()
ais_ll_list, ais_log_z_list = [], []
best_valid_ll, best_test_ll = -np.inf, -np.inf
for step in range(1, num_steps + 1):

    # after one epoch, shuffle and refill the queue
    try:
        x_batch = dataset_iter.next()
    except StopIteration:
        dataset_iter = tfe.Iterator(train_dataset)
        x_batch = dataset_iter.next()

    # anneal learning rate if necessary
    if step % anneal_steps == 0: