# Values of model_config['data_type'] that _select_input_channels can handle.
_SUPPORTED_DATA_TYPES = ('mag', 'mag_phase', 'mag_phase_diff', 'real_imag',
                         'mag_real_imag', 'mag_phase2', 'mag_phase_real_imag')


def _select_input_channels(spec, data_type):
    """Return the model-input tensor sliced from ``spec`` for ``data_type``.

    ``spec`` is indexed as ``spec[:, :, :-1, c]``: a rank-4 tensor whose third
    axis is trimmed by its last element and whose last axis holds channels.
    Judging by the ``data_type`` names, channels 0:2 are presumably the
    real/imaginary parts, 2 the magnitude and 3 the phase -- TODO confirm
    against ``dataset.prepare_datasets``.

    Raises:
        ValueError: if ``data_type`` is not one of ``_SUPPORTED_DATA_TYPES``.
    """
    if data_type == 'mag':
        return tf.expand_dims(spec[:, :, :-1, 2], 3)
    if data_type in ('mag_phase', 'mag_phase_diff'):
        return spec[:, :, :-1, 2:4]
    if data_type == 'real_imag':
        return spec[:, :, :-1, 0:2]
    if data_type in ('mag_real_imag', 'mag_phase2'):
        # Channel 2 first, followed by channels 0:2.
        return tf.concat([tf.expand_dims(spec[:, :, :-1, 2], 3),
                          spec[:, :, :-1, 0:2]], 3)
    if data_type == 'mag_phase_real_imag':
        return spec[:, :, :-1, :]
    raise ValueError(
        'Unsupported data_type {dt!r}; expected one of {opts}'.format(
            dt=data_type, opts=', '.join(_SUPPORTED_DATA_TYPES)))


def do_experiment(model_config):
    """Build, optionally restore, train and test a MagnitudeModel for one run.

    Resets the default TensorFlow graph, prepares the train/validation/test
    datasets, slices the model inputs selected by
    ``model_config['data_type']``, then trains and tests the model and prints
    the resulting test losses. Summaries go to ``<log_dir>/<experiment id>``.

    Args:
        model_config: Experiment configuration mapping. Keys read directly
            here: 'data_type', 'GPU', 'model_variant', 'learning_rate',
            'phase_weight', 'loading' (plus 'model_base_dir' and
            'checkpoint_to_load' when loading), 'log_dir' and
            'initialisation_test'. It is also passed to
            ``dataset.prepare_datasets``, ``train`` and ``test``, which may
            read further keys.

    Raises:
        ValueError: if ``model_config['data_type']`` is not one of
            ``_SUPPORTED_DATA_TYPES``.
    """
    data_type = model_config['data_type']
    # Fail fast, before the (potentially slow) dataset preparation. An unknown
    # value previously surfaced much later as an UnboundLocalError.
    if data_type not in _SUPPORTED_DATA_TYPES:
        raise ValueError(
            'Unsupported data_type {dt!r}; expected one of {opts}'.format(
                dt=data_type, opts=', '.join(_SUPPORTED_DATA_TYPES)))

    tf.reset_default_graph()
    experiment_id = ex.current_run._id
    print('Experiment ID: {eid}'.format(eid=experiment_id))

    # Prepare data
    print('Preparing dataset')
    train_data, val_data, test_data = dataset.prepare_datasets(model_config)
    print('Dataset ready')

    # Start session, restricted to the configured GPU(s).
    tf_config = tf.ConfigProto()
    #tf_config.gpu_options.allow_growth = True
    tf_config.gpu_options.visible_device_list = str(model_config['GPU'])
    sess = tf.Session(config=tf_config)
    #sess = tf_debug.LocalCLIDebugWrapperSession(sess, ui_type="readline")

    print('Session started')

    writer = None
    try:
        # One feedable iterator. The dataset it reads from is chosen by the
        # string handle fed into `handle` (presumably done inside train/test).
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_data.output_types, train_data.output_shapes)
        (mixed_spec, voice_spec, background_spec, mixed_audio, voice_audio,
         background_audio) = iterator.get_next()

        training_iterator = train_data.make_initializable_iterator()
        validation_iterator = val_data.make_initializable_iterator()
        testing_iterator = test_data.make_initializable_iterator()

        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        testing_handle = sess.run(testing_iterator.string_handle())
        print('Iterators created')

        # Create variable placeholders and model
        is_training = tf.placeholder(shape=(), dtype=bool)
        mixed_phase = tf.expand_dims(mixed_spec[:, :, :-1, 3], 3)

        print('Creating model')
        mixed_input = _select_input_channels(mixed_spec, data_type)
        voice_input = _select_input_channels(voice_spec, data_type)

        model = audio_models.MagnitudeModel(mixed_input,
                                            voice_input,
                                            mixed_phase,
                                            mixed_audio,
                                            voice_audio,
                                            background_audio,
                                            model_config['model_variant'],
                                            is_training,
                                            model_config['learning_rate'],
                                            data_type,
                                            model_config['phase_weight'],
                                            name='Magnitude_Model')

        sess.run(tf.global_variables_initializer())

        if model_config['loading']:
            print('Loading checkpoint')
            checkpoint = os.path.join(model_config['model_base_dir'],
                                      model_config['checkpoint_to_load'])
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint)

        # Summaries
        model_folder = str(experiment_id)
        writer = tf.summary.FileWriter(os.path.join(model_config["log_dir"],
                                                    model_folder),
                                       graph=sess.graph)

        # Get baseline metrics at initialisation
        test_count = 0
        if model_config['initialisation_test']:
            print('Running initialisation test')
            initial_test_loss, test_count = test(sess, model, model_config,
                                                 handle, testing_iterator,
                                                 testing_handle, test_count,
                                                 experiment_id)

        # Train the model
        model = train(sess, model, model_config, model_folder, handle,
                      training_iterator, training_handle, validation_iterator,
                      validation_handle, writer)

        # Test trained model
        mean_test_loss, test_count = test(sess, model, model_config, handle,
                                          testing_iterator, testing_handle,
                                          test_count, experiment_id)
    finally:
        # Flush pending summaries to disk and release the session, even when
        # training or testing raises.
        if writer is not None:
            writer.close()
        sess.close()

    print('{ts}:\n\tAll done with experiment {exid}!'.format(
        ts=datetime.datetime.now(), exid=experiment_id))
    if model_config['initialisation_test']:
        print('\tInitial test loss: {init}'.format(init=initial_test_loss))
    print('\tFinal test loss: {final}'.format(final=mean_test_loss))
Example #2
0
def do_experiment(model_config):
    """Build, optionally restore, train and test a spectrogram model for one run.

    Resets the default TensorFlow graph and prepares the train/validation/test
    datasets. Then builds a ``MagnitudeModel`` when ``model_config['mag_phase']``
    is truthy, otherwise a ``ComplexNumberModel``. Finally it trains and tests
    the model and prints the resulting test losses. Summaries go to
    ``<log_dir>/<experiment id>``.

    Args:
        model_config: Experiment configuration mapping. Keys read directly here:
            'mag_phase', 'model_variant', 'loading' (plus 'model_base_dir' and
            'checkpoint_to_load' when loading), 'log_dir' and
            'initialisation_test'. It is also passed to
            ``dataset.prepare_datasets``, ``train`` and ``test``, which may read
            further keys.
    """
    tf.reset_default_graph()
    experiment_id = ex.current_run._id
    print('Experiment ID: {eid}'.format(eid=experiment_id))

    # Prepare data
    print('Preparing dataset')
    train_data, val_data, test_data = dataset.prepare_datasets(model_config)
    print('Dataset ready')

    # Start session
    #tf_config = tf.ConfigProto()
    #tf_config.gpu_options.allow_growth = True
    #tf_config.gpu_options.visible_device_list = "0"
    #sess = tf.Session(config=tf_config)
    sess = tf.Session()

    print('Session started')

    writer = None
    try:
        # One feedable iterator. The dataset it reads from is chosen by the string
        # handle fed into `handle` (presumably done inside train/test).
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(handle, train_data.output_types, train_data.output_shapes)
        mixed_spec, voice_spec, mixed_audio, voice_audio = iterator.get_next()

        training_iterator = train_data.make_initializable_iterator()
        validation_iterator = val_data.make_initializable_iterator()
        testing_iterator = test_data.make_initializable_iterator()

        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        testing_handle = sess.run(testing_iterator.string_handle())
        print('Iterators created')

        # Create variable placeholders and model
        is_training = tf.placeholder(shape=(), dtype=bool)
        if model_config['mag_phase']:
            # Channels 0 and 1 are presumably magnitude and phase (per the
            # variable names) -- TODO confirm against dataset.prepare_datasets.
            mixed_mag = tf.expand_dims(mixed_spec[:, :, :-1, 0], 3)
            mixed_phase = tf.expand_dims(mixed_spec[:, :, :-1, 1], 3)
            voice_mag = tf.expand_dims(voice_spec[:, :, :-1, 0], 3)

            print('Creating model')
            model = audio_models.MagnitudeModel(mixed_mag, voice_mag, mixed_phase, mixed_audio, voice_audio,
                                                model_config['model_variant'], is_training, name='U_Net_Model')
        else:
            # Keep every channel; only the last element of the third axis is dropped.
            mixed_spec_trim = mixed_spec[:, :, :-1, :]
            voice_spec_trim = voice_spec[:, :, :-1, :]

            print('Creating model')
            model = audio_models.ComplexNumberModel(mixed_spec_trim, voice_spec_trim, mixed_audio, voice_audio,
                                                    model_config['model_variant'], is_training)

        sess.run(tf.global_variables_initializer())

        if model_config['loading']:
            # TODO - Think this works now but needs proper testing
            print('Loading checkpoint')
            checkpoint = os.path.join(model_config['model_base_dir'], model_config['checkpoint_to_load'])
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint)

        # Summaries
        model_folder = str(experiment_id)
        writer = tf.summary.FileWriter(os.path.join(model_config["log_dir"], model_folder), graph=sess.graph)

        # Get baseline metrics at initialisation
        test_count = 0
        if model_config['initialisation_test']:
            print('Running initialisation test')
            initial_test_loss, test_count = test(sess, model, model_config, handle, testing_iterator, testing_handle,
                                                 writer, test_count, experiment_id)

        # Train the model
        model = train(sess, model, model_config, model_folder, handle, training_iterator, training_handle,
                      validation_iterator, validation_handle, writer)

        # Test trained model
        mean_test_loss, test_count = test(sess, model, model_config, handle, testing_iterator, testing_handle, writer,
                                          test_count, experiment_id)
    finally:
        # Flush pending summaries (including those written by test()) to disk
        # and release the session, even when training or testing raises.
        if writer is not None:
            writer.close()
        sess.close()

    print('{ts}:\n\tAll done with experiment {exid}!'.format(ts=datetime.datetime.now(), exid=experiment_id))
    if model_config['initialisation_test']:
        print('\tInitial test loss: {init}'.format(init=initial_test_loss))
    print('\tFinal test loss: {final}'.format(final=mean_test_loss))