import os
import time

import tensorflow as tf

import load_dataset


def extract_all_reconstructed_spectrograms(spectrograms_directory,
                                           checkpoint_dir=None):
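    """Dump reconstructed spectrograms for the train, test, and devel splits.

    Relies on `extract_reconstructed_spectrograms` and the `load_dataset`
    input functions defined elsewhere in this project.
    """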
    # Default to the module-level checkpoint directory configured below;
    # it is looked up at call time, so it need not exist at import time.
    if checkpoint_dir is None:
        checkpoint_dir = checkpoint_directory
    extract_reconstructed_spectrograms(
        load_dataset.train_input_fn(BATCH_SIZE),
        spectrograms_directory + "/train",
        checkpoint_directory=checkpoint_dir)
    extract_reconstructed_spectrograms(
        load_dataset.test_input_fn(BATCH_SIZE),
        spectrograms_directory + "/test",
        checkpoint_directory=checkpoint_dir)
    extract_reconstructed_spectrograms(
        load_dataset.val_input_fn(BATCH_SIZE),
        spectrograms_directory + "/devel",
        checkpoint_directory=checkpoint_dir)


def extract_all_features(features_directory,
                         features_name,
                         checkpoint_dir=None):
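    """Write encoder features for the train, test, and devel splits to CSV
    files named `<features_name>.<split>.csv` under `features_directory`.
    """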
    # Default to the module-level checkpoint directory configured below.
    if checkpoint_dir is None:
        checkpoint_dir = checkpoint_directory
    extract_features(load_dataset.train_input_fn(BATCH_SIZE),
                     features_directory + '/' + features_name + ".train.csv",
                     checkpoint_directory=checkpoint_dir)
    extract_features(load_dataset.test_input_fn(BATCH_SIZE),
                     features_directory + '/' + features_name + ".test.csv",
                     checkpoint_directory=checkpoint_dir)
    extract_features(load_dataset.val_input_fn(BATCH_SIZE),
                     features_directory + '/' + features_name + ".devel.csv",
                     checkpoint_directory=checkpoint_dir)


def train_network(epochs, batch_size, learning_rate, num_units,
                  checkpoint_directory):
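    """Train the encoder/decoder model for `epochs` epochs.

    Overwrites the module-level hyperparameter globals, restores the latest
    checkpoint from `checkpoint_directory` when one exists, and saves a new
    checkpoint every ten epochs. Assumes `encoder`, `decoder`, and
    `train_step_network` are defined elsewhere in this module.
    """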
    global EPOCHS
    global BATCH_SIZE
    global LEARNING_RATE
    global NUM_UNITS
    global optimizer

    EPOCHS = epochs
    BATCH_SIZE = batch_size
    LEARNING_RATE = learning_rate
    NUM_UNITS = num_units

    global_step = tf.train.create_global_step()
    checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                     encoder=encoder,
                                     decoder=decoder,
                                     global_step=global_step)
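    # Restore the most recent checkpoint if one exists; restoring from None
    # (an empty directory) is effectively a no-op.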
    checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
    with writer.as_default():
        tf.contrib.summary.initialize()
        begin_time = time.time()

        for epoch in range(1, EPOCHS + 1):

            dataset = load_dataset.train_input_fn(BATCH_SIZE)
            epoch_time = now_time = time.time()
            total_loss = 0
            num_step = 0
            for (batch, (inputs, filenames, labels)) in enumerate(dataset):
                # The final batch may hold fewer than BATCH_SIZE examples;
                # stop the epoch there so the fixed-size hidden state fits.
                if inputs['image'].shape[0] < BATCH_SIZE:
                    break
                encoder_hidden = encoder.initialize_hidden_state()
                # Rescale pixel values from [0, 255] to [-1, 1].
                inputs['image'] = (inputs['image'] / 255 * 2) - 1
                batch_loss = train_step_network(inputs['image'],
                                                encoder_hidden)
                total_loss += batch_loss
                num_step = batch + 1
                global_step.assign_add(1)
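                # Emit the per-batch loss to TensorBoard on every step.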
                with tf.contrib.summary.record_summaries_every_n_global_steps(
                        1):
                    tf.contrib.summary.scalar('loss_attention', batch_loss)
                now_time = time.time()
                print('Epoch {} Step {} Loss {:.4f} Elapsed time {}'.format(
                    epoch, num_step, batch_loss.numpy(),
                    time.strftime("%H:%M:%S",
                                  time.gmtime(now_time - begin_time))))
            print('Epoch {} Loss {:.4f} Duration {}'.format(
                epoch, float(total_loss) / num_step,
                time.strftime("%H:%M:%S", time.gmtime(now_time - epoch_time))))
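            # Persist the model state every ten epochs.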
            if epoch % 10 == 0:
                checkpoint.save(file_prefix=checkpoint_prefix)


tf.enable_eager_execution()

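# Default hyperparameters; train_network() overwrites these module globals.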
EPOCHS = 50
BATCH_SIZE = 64
LEARNING_RATE = 0.001
NUM_UNITS = 256

logdir = "./tensorboard_logs/attention_bi_bi_hidden"
writer = tf.contrib.summary.create_file_writer(logdir)
# Summary recording is configured per step inside train_network via
# record_summaries_every_n_global_steps(1).
optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
checkpoint_directory = './training_checkpoints/attention_bi_bi_hidden'
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
dataset = load_dataset.train_input_fn(BATCH_SIZE)
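
# Example driver (a sketch; assumes `encoder` and `decoder` are instantiated
# further down, after the class definitions, and "./features" is a
# placeholder output path):
#
#     train_network(EPOCHS, BATCH_SIZE, LEARNING_RATE, NUM_UNITS,
#                   checkpoint_directory)
#     extract_all_features("./features", "attention_bi_bi_hidden")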


class Encoder(tf.keras.Model):
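    """Recurrent encoder built from paired forward and backward GRU layers."""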
    def __init__(self, num_units, batch_size):
        super(Encoder, self).__init__()
        self.batch_size = batch_size
        self.num_units = num_units
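        # Forward and backward GRUs over the input frames; together they
        # form the first bidirectional recurrent layer.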
        self.gru_1_f = tf.keras.layers.GRU(
            self.num_units,
            return_sequences=True,
            recurrent_initializer='glorot_uniform',
            return_state=True)
        self.gru_1_b = tf.keras.layers.GRU(
            self.num_units,
            return_sequences=True,