Example #1
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)
    # Load the training wav files and their labels
    wav_files, labels, label_to_id = get_file_and_labels(
        os.path.join(FLAGS.data_dir, 'train_labels'))
    wav_files = [
        os.path.join(FLAGS.data_dir, 'train', wav_file)
        for wav_file in wav_files
    ]

    train_num_classes = len(label_to_id)
    if FLAGS.num_gpus > 1:
        # MirroredStrategy: This does in-graph replication with synchronous training on many GPUs on one machine.
        # Essentially, we create copies of all variables in the model's layers on each device.
        # We then use all-reduce to combine gradients across the devices
        # before applying them to the variables to keep them in sync.
        # Reference: https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/distribute
        train_distribute = tf.contrib.distribute.MirroredStrategy()
        run_config = tf.estimator.RunConfig(train_distribute=train_distribute)
    else:
        run_config = None

    model = create_model(config=run_config,
                         model_dir=FLAGS.model_dir,
                         params={
                             'num_classes': train_num_classes,
                             **FLAGS.__dict__
                         })
    train_input_fn = lambda: get_input_function(
        wav_files=wav_files, labels=labels, is_training=True, **FLAGS.__dict__)
    model.train(train_input_fn, steps=FLAGS.num_steps)
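The snippet above relies on a module-level FLAGS object and on helpers (get_file_and_labels, create_model, get_input_function) from the surrounding training script. A minimal sketch of how the entry point could be wired up is shown below; the argparse-based FLAGS is an assumption (chosen because the code unpacks FLAGS.__dict__), and the defaults are placeholders.

# Sketch only: assumed flag definitions and entry point for main() above.
import argparse
import sys

import tensorflow as tf

parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='data')    # contains train/ and train_labels
parser.add_argument('--model_dir', default='model')  # Estimator checkpoint directory
parser.add_argument('--num_gpus', type=int, default=1)
parser.add_argument('--num_steps', type=int, default=10000)
# ...plus the feature-extraction arguments forwarded via **FLAGS.__dict__.

if __name__ == '__main__':
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
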
Example #2
    def test_get_input_function(self):
        wav_files = ['../data/train/121624931534904112937-0.wav',
                     '../data/train/121624931534904112937-1.wav',
                     '../data/train/121624931534904112937-2.wav'
                     ]
        labels = [0, 1, 2]
        desired_ms = 100
        window_size_ms = 25
        window_stride_ms = 10
        batch_size = 2
        features, label_ids = get_input_function(wav_files,
                                                 labels,
                                                 batch_size=batch_size,
                                                 desired_ms=desired_ms,
                                                 window_size_ms=window_size_ms,
                                                 window_stride_ms=window_stride_ms,
                                                 magnitude_squared=True,
                                                 input_feature_dim=40,
                                                 input_feature='fbank',
                                                 is_training=True)
        labels_readout = []
        repeated_times = 10
        with tf.Session() as sess:
            for i in range(repeated_times):
                label_ids_val = sess.run(label_ids)
                labels_readout.extend(label_ids_val)

        self.assertEqual(len(labels_readout), repeated_times * batch_size, 'total number of labels')
        # sorted() returns a new list; list.sort() sorts in place and returns
        # None, which would make the original comparison vacuously pass.
        self.assertEqual(sorted(set(labels_readout)), sorted(set(labels)),
                         'the same unique labels')
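The test method above is shown without its enclosing class; judging by the self.assertEqual calls it belongs to a unittest-style test case. A possible skeleton, with the class name and module import assumed for illustration:

# Sketch only: class name and module import are assumptions.
import tensorflow as tf

from input_function import get_input_function  # assumed module name


class GetInputFunctionTest(tf.test.TestCase):
    # test_get_input_function from Example #2 goes here.
    pass


if __name__ == '__main__':
    tf.test.main()
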
Example #3
def get_embeddings(model, wav_files, **kwargs):
    # Labels are not needed at inference time; use -1 as a placeholder id per file.
    label_ids = [-1 for _ in wav_files]

    predict_input_fn = lambda: get_input_function(
        wav_files, label_ids, is_training=False, **kwargs)

    embeddings = []
    for prediction in model.predict(predict_input_fn,
                                    yield_single_examples=False):
        embeddings.extend(prediction['embeddings'])

    return embeddings
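
get_embeddings runs the trained estimator in prediction mode and collects the 'embeddings' output for every input wav file. A usage sketch follows; create_model comes from the surrounding module (see Example #1), and the paths, num_classes value and feature parameters are placeholders chosen to match the settings exercised in Example #2.

# Sketch only: all concrete values below are placeholders.
test_wavs = ['../data/test/sample-0.wav', '../data/test/sample-1.wav']
model = create_model(config=None,
                     model_dir='model',
                     params={'num_classes': 10})  # plus whatever else create_model expects
embeddings = get_embeddings(model,
                            test_wavs,
                            batch_size=2,
                            desired_ms=100,
                            window_size_ms=25,
                            window_stride_ms=10,
                            magnitude_squared=True,
                            input_feature_dim=40,
                            input_feature='fbank')
print('extracted %d embedding vectors' % len(embeddings))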