def get_baseline_batch(self, hparams):
  """Get the Tensor expressions from the reader.

  Builds a TF1 queue-based input pipeline: slices one 4-second (64000-sample)
  audio crop and its labels out of a reader example, computes a spectrogram,
  optionally pads/crops it, and batches everything.

  Args:
    hparams: Hyperparameters object with specgram parameters (`hop_length`,
        `n_fft`, `mask`, `log_mag`, `re_im`, `dphase`, `mag_only`, `pad`,
        `batch_size`). Both `hop_length` and `n_fft` must be set.

  Returns:
    A dict of key:tensor pairs with keys "pitch", "velocity", "audio",
    "instrument_source", "instrument_family", "qualities", and "spectrogram".

  Raises:
    ValueError: If `hparams.hop_length` or `hparams.n_fft` is unset, since the
        spectrogram cannot be computed without them.
  """
  example = self.get_example(hparams.batch_size)
  # Fixed-length crop: 64000 samples (4 seconds at 16kHz — assumed from the
  # NSynth convention; confirm against the reader).
  audio = tf.slice(example["audio"], [0], [64000])
  audio = tf.reshape(audio, [1, 64000])
  pitch = tf.slice(example["pitch"], [0], [1])
  velocity = tf.slice(example["velocity"], [0], [1])
  instrument_source = tf.slice(example["instrument_source"], [0], [1])
  instrument_family = tf.slice(example["instrument_family"], [0], [1])
  qualities = tf.slice(example["qualities"], [0], [10])
  qualities = tf.reshape(qualities, [1, 10])

  # Get Specgrams
  hop_length = hparams.hop_length
  n_fft = hparams.n_fft
  if not (hop_length and n_fft):
    # BUG FIX: previously `specgram` was silently left undefined here, and the
    # batching code below failed with an unrelated NameError. Fail fast.
    raise ValueError(
        "Both hparams.hop_length and hparams.n_fft must be set to compute "
        "spectrograms.")
  specgram = utils.tf_specgram(
      audio,
      n_fft=n_fft,
      hop_length=hop_length,
      mask=hparams.mask,
      log_mag=hparams.log_mag,
      re_im=hparams.re_im,
      dphase=hparams.dphase,
      mag_only=hparams.mag_only)
  # Static shape comes from a registry keyed on the STFT parameters.
  shape = [1] + SPECGRAM_REGISTRY[(n_fft, hop_length)]
  if hparams.mag_only:
    shape[-1] = 1  # Magnitude-only spectrogram has a single channel.
  specgram = tf.reshape(specgram, shape)
  # BUG FIX: tf.logging.info takes a %-format string; the tensor used to be
  # passed as an ignored positional arg and was never actually logged.
  tf.logging.info("SPECGRAM BEFORE PADDING: %s", specgram)

  if hparams.pad:
    # Pad the time axis up to the next power of two, then crop one frequency
    # bin so the spectrogram fits a square (e.g. 256x256) layout.
    num_padding = 2**int(np.ceil(np.log(shape[2]) / np.log(2))) - shape[2]
    # Typo fix ("num_pading") and lazy %-args instead of eager formatting.
    tf.logging.info("num_padding: %d", num_padding)
    # NOTE: a redundant reshape to `shape` was removed here — `specgram`
    # already has exactly that shape from the reshape above.
    specgram = tf.pad(specgram, [[0, 0], [0, 0], [0, num_padding], [0, 0]])
    specgram = tf.slice(specgram, [0, 0, 0, 0], [-1, shape[1] - 1, -1, -1])
    tf.logging.info("SPECGRAM AFTER PADDING: %s", specgram)

  # Form a Batch
  if self.is_training:
    # Shuffled batching for training; capacities scale with batch size.
    (audio, velocity, pitch, specgram,
     instrument_source, instrument_family,
     qualities) = tf.train.shuffle_batch(
         [
             audio, velocity, pitch, specgram,
             instrument_source, instrument_family, qualities
         ],
         batch_size=hparams.batch_size,
         capacity=20 * hparams.batch_size,
         min_after_dequeue=10 * hparams.batch_size,
         enqueue_many=True)
  elif hparams.batch_size > 1:
    # Deterministic (unshuffled) batching for eval when batch_size > 1.
    (audio, velocity, pitch, specgram,
     instrument_source, instrument_family,
     qualities) = tf.train.batch(
         [
             audio, velocity, pitch, specgram,
             instrument_source, instrument_family, qualities
         ],
         batch_size=hparams.batch_size,
         capacity=10 * hparams.batch_size,
         enqueue_many=True)

  # Pin the static batch dimension so downstream graph code sees full shapes.
  audio.set_shape([hparams.batch_size, 64000])

  return dict(
      pitch=pitch,
      velocity=velocity,
      audio=audio,
      instrument_source=instrument_source,
      instrument_family=instrument_family,
      qualities=qualities,
      spectrogram=specgram)