def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)

    # Define the input function for training.
    wav_files, labels, label_to_id = get_file_and_labels(
        os.path.join(FLAGS.data_dir, 'train_labels'))
    wav_files = [
        os.path.join(FLAGS.data_dir, 'train', wav_file) for wav_file in wav_files
    ]
    train_num_classes = len(label_to_id)

    if FLAGS.num_gpus > 1:
        # MirroredStrategy: this does in-graph replication with synchronous
        # training on many GPUs on one machine. Essentially, we create copies
        # of all variables in the model's layers on each device. We then use
        # all-reduce to combine gradients across the devices before applying
        # them to the variables to keep them in sync.
        # Reference: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute
        train_distribute = tf.contrib.distribute.MirroredStrategy()
        run_config = tf.estimator.RunConfig(train_distribute=train_distribute)
    else:
        run_config = None

    model = create_model(
        config=run_config,
        model_dir=FLAGS.model_dir,
        params={
            'num_classes': train_num_classes,
            **FLAGS.__dict__
        })

    train_input_fn = lambda: get_input_function(
        wav_files=wav_files,
        labels=labels,
        is_training=True,
        **FLAGS.__dict__)
    model.train(train_input_fn, steps=FLAGS.num_steps)
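A minimal sketch of the flag definitions and entry point that main() assumes. The flag names are taken from the attributes read above (data_dir, model_dir, num_gpus, num_steps); the defaults and help strings are assumptions, not the repository's actual flags module.

# Assumed entry-point sketch; defaults and help strings are illustrative only.
import os
import tensorflow as tf

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('data_dir', './data',
                           'Directory containing train_labels and the train/ wav files.')
tf.app.flags.DEFINE_string('model_dir', './model',
                           'Directory for Estimator checkpoints and summaries.')
tf.app.flags.DEFINE_integer('num_gpus', 1,
                            'Number of GPUs; values > 1 enable MirroredStrategy.')
tf.app.flags.DEFINE_integer('num_steps', 10000,
                            'Number of training steps.')

if __name__ == '__main__':
    tf.app.run(main=main)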
def test_get_input_function(self):
    wav_files = [
        '../data/train/121624931534904112937-0.wav',
        '../data/train/121624931534904112937-1.wav',
        '../data/train/121624931534904112937-2.wav',
    ]
    labels = [0, 1, 2]
    desired_ms = 100
    window_size_ms = 25
    window_stride_ms = 10
    batch_size = 2
    features, label_ids = get_input_function(
        wav_files,
        labels,
        batch_size=batch_size,
        desired_ms=desired_ms,
        window_size_ms=window_size_ms,
        window_stride_ms=window_stride_ms,
        magnitude_squared=True,
        input_feature_dim=40,
        input_feature='fbank',
        is_training=True)

    labels_readout = []
    repeated_times = 10
    with tf.Session() as sess:
        for _ in range(repeated_times):
            label_ids_val = sess.run(label_ids)
            labels_readout.extend(label_ids_val)

    self.assertEqual(len(labels_readout), repeated_times * batch_size,
                     'total number of labels')
    # list.sort() returns None, so comparing the results of .sort() would
    # always pass. Compare sorted copies of the unique labels instead.
    self.assertEqual(sorted(set(labels_readout)), sorted(set(labels)),
                     'the same unique labels')
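A minimal harness sketch for running this test. The class name and the module imported for get_input_function are assumptions, since the original test file's scaffolding is not shown.

# Assumed test scaffolding; class and module names are hypothetical.
import tensorflow as tf

from input_function import get_input_function  # hypothetical module name

class GetInputFunctionTest(tf.test.TestCase):
    # The test_get_input_function method above would live here.
    pass

if __name__ == '__main__':
    tf.test.main()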
def get_embeddings(model, wav_files, **kwargs):
    # Labels are not used at inference time; pass placeholder ids.
    label_ids = [-1 for _ in wav_files]
    predict_input_fn = lambda: get_input_function(
        wav_files, label_ids, is_training=False, **kwargs)
    embeddings = []
    for prediction in model.predict(predict_input_fn,
                                    yield_single_examples=False):
        embeddings.extend(prediction['embeddings'])
    return embeddings
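A hypothetical usage sketch, assuming a trained Estimator returned by create_model() and the same feature-extraction keyword arguments used in the test above; the wav paths and parameter values are illustrative only.

# Hypothetical usage; wav paths and feature parameters are illustrative.
eval_wav_files = ['../data/eval/example-0.wav', '../data/eval/example-1.wav']
embeddings = get_embeddings(
    model,  # trained tf.estimator.Estimator from create_model()
    eval_wav_files,
    batch_size=2,
    desired_ms=100,
    window_size_ms=25,
    window_stride_ms=10,
    magnitude_squared=True,
    input_feature_dim=40,
    input_feature='fbank')
# Each entry is one utterance-level embedding, e.g. usable for cosine-similarity scoring.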