def do_experiment(model_config):
    """Build the U-Net graph, optionally restore a checkpoint, then train and test it.

    Keys of ``model_config`` read directly here: ``loading``,
    ``model_base_dir``, ``checkpoint_to_load``, ``log_dir`` and
    ``INITIALISATION_TEST``. The whole dict is also forwarded to
    ``Dataset.prepare_datasets``, ``train`` and ``test``.

    Summaries are written under ``<log_dir>/<run id>``, where the run id is
    ``ex.current_run._id`` of the module-level experiment object ``ex``.

    NOTE(review): this file defines ``do_experiment`` more than once; at
    import time the last definition wins, so confirm which one is intended.

    Args:
        model_config: dict of experiment settings (see keys above).
    """
    tf.reset_default_graph()
    experiment_id = ex.current_run._id
    print('Experiment ID: {eid}'.format(eid=experiment_id))

    # Prepare data
    print('Preparing dataset')
    train_data, val_data, test_data = Dataset.prepare_datasets(model_config)

    # Start session. allow_growth makes TF allocate GPU memory on demand
    # instead of reserving it all when the session is created.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)
    writer = None
    try:
        # Create iterators. One feedable iterator is driven by a string
        # handle, so the same graph inputs can be switched between the
        # train / validation / test datasets at run time via ``handle``.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(handle,
                                                       train_data.output_types,
                                                       train_data.output_shapes)
        mixed_spec, voice_spec, mixed_audio, voice_audio = iterator.get_next()

        training_iterator = train_data.make_initializable_iterator()
        validation_iterator = val_data.make_initializable_iterator()
        testing_iterator = test_data.make_initializable_iterator()

        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        testing_handle = sess.run(testing_iterator.string_handle())

        # Create variable placeholders.
        # NOTE(review): the slicing assumes rank-4 spectrogram tensors with
        # magnitude in channel 0 and phase in channel 1; index 0 of axis 2
        # is dropped. Confirm against Dataset.prepare_datasets.
        is_training = tf.placeholder(shape=(), dtype=bool)
        mixed_mag = tf.expand_dims(mixed_spec[:, :, 1:, 0], 3)
        mixed_phase = tf.expand_dims(mixed_spec[:, :, 1:, 1], 3)
        voice_mag = tf.expand_dims(voice_spec[:, :, 1:, 0], 3)

        # Build U-Net model
        print('Creating model')
        model = UNet.UNetModel(mixed_mag,
                               voice_mag,
                               mixed_phase,
                               mixed_audio,
                               voice_audio,
                               'unet',
                               is_training,
                               name='U_Net_Model')

        # Initialise all variables BEFORE restoring. Running the initializer
        # after Saver.restore() would overwrite the restored weights with
        # fresh initial values, silently discarding the checkpoint.
        sess.run(tf.global_variables_initializer())

        if model_config['loading']:
            # TODO - restore path still needs proper testing
            print('Loading checkpoint')
            checkpoint = os.path.join(model_config['model_base_dir'],
                                      model_config['checkpoint_to_load'])
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint)

        # Summaries
        model_folder = str(experiment_id)
        writer = tf.summary.FileWriter(os.path.join(model_config["log_dir"],
                                                    model_folder),
                                       graph=sess.graph)

        # Baseline metrics before training (initialised or restored weights)
        test_count = 0
        initial_test_loss = None
        if model_config['INITIALISATION_TEST']:
            print('Running initialisation test')
            initial_test_loss, test_count = test(sess, model, model_config,
                                                 handle, testing_iterator,
                                                 testing_handle, writer,
                                                 test_count)

        # Train the model
        model = train(sess, model, model_config, model_folder, handle,
                      training_iterator, training_handle, validation_iterator,
                      validation_handle, writer)

        # Test trained model
        mean_test_loss, test_count = test(sess, model, model_config, handle,
                                          testing_iterator, testing_handle,
                                          writer, test_count)
        print('{ts}:\n\tAll done!'.format(ts=datetime.datetime.now()))
        if model_config['INITIALISATION_TEST']:
            print('\tInitial test loss: {init}'.format(init=initial_test_loss))
        print('\tFinal test loss: {final}'.format(final=mean_test_loss))
    finally:
        # Flush pending summaries to disk and release session resources,
        # even if training or testing raised.
        if writer is not None:
            writer.close()
        sess.close()
def do_experiment(model_config):
    """Build the U-Net graph, optionally restore a checkpoint, then train and test it.

    Always runs a test pass before training and another after training, and
    prints both losses.

    Keys of ``model_config`` read directly here: ``loading``,
    ``model_base_dir``, ``checkpoint_to_load`` and ``log_dir``. The whole
    dict is also forwarded to ``Dataset.prepare_datasets``, ``train`` and
    ``test``.

    Summaries are written under ``<log_dir>/<run id>``, where the run id is
    ``ex.current_run._id`` of the module-level experiment object ``ex``.

    NOTE(review): this file defines ``do_experiment`` more than once; this
    later definition shadows the earlier one at import time.

    Args:
        model_config: dict of experiment settings (see keys above).
    """
    tf.reset_default_graph()
    experiment_id = ex.current_run._id
    print('Experiment ID: {eid}'.format(eid=experiment_id))

    # Prepare data
    print('Preparing dataset')
    train_data, val_data, test_data = Dataset.prepare_datasets(model_config)

    # Start session
    sess = tf.Session()
    writer = None
    try:
        # Create iterators. One feedable iterator is driven by a string
        # handle, so the same graph inputs can be switched between the
        # train / validation / test datasets at run time via ``handle``.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(handle, train_data.output_types,
                                                       train_data.output_shapes)
        mixed, voice = iterator.get_next()

        training_iterator = train_data.make_initializable_iterator()
        validation_iterator = val_data.make_initializable_iterator()
        testing_iterator = test_data.make_initializable_iterator()

        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        testing_handle = sess.run(testing_iterator.string_handle())

        # Create variable placeholders.
        # Each dataset element is a tuple; only element [0] is used here
        # (workaround for the dataset's tuple nesting).
        # NOTE(review): assumes element [0] is a rank-4 spectrogram with
        # magnitude in channel 0 and phase in channel 1. Confirm against
        # Dataset.prepare_datasets.
        is_training = tf.placeholder(shape=(), dtype=bool)
        mixed_mag = tf.expand_dims(mixed[0][:, :, 1:, 0], 3)
        mixed_phase = tf.expand_dims(mixed[0][:, :, 1:, 1], 3)
        voice_mag = tf.expand_dims(voice[0][:, :, 1:, 0], 3)

        # Build U-Net model
        print('Creating model')
        model = UNet.UNetModel(mixed_mag, voice_mag, mixed_phase, is_training,
                               name='U_Net_Model')

        # Initialise all variables BEFORE restoring. Running the initializer
        # after Saver.restore() would overwrite the restored weights with
        # fresh initial values, silently discarding the checkpoint.
        sess.run(tf.global_variables_initializer())

        if model_config['loading']:
            # TODO - restore path still needs proper testing
            print('Loading checkpoint')
            checkpoint = os.path.join(model_config['model_base_dir'],
                                      model_config['checkpoint_to_load'])
            restorer = tf.train.Saver()
            restorer.restore(sess, checkpoint)

        # Summaries
        model_folder = str(experiment_id)
        writer = tf.summary.FileWriter(os.path.join(model_config["log_dir"], model_folder),
                                       graph=sess.graph)

        # Baseline metrics before training (initialised or restored weights).
        # NOTE(review): test()'s result is used as a single value here, but
        # another do_experiment in this file unpacks it as (loss, count).
        # Confirm which signature test() actually has.
        print('Running initialisation test')
        test_count = 1
        initial_test_loss = test(sess, model, model_config, handle, testing_iterator,
                                 testing_handle, writer, test_count)
        test_count += 1

        # Train the model
        model = train(sess, model, model_config, model_folder, handle, training_iterator,
                      training_handle, validation_iterator, validation_handle, writer)

        # Test trained model
        mean_test_loss = test(sess, model, model_config, handle, testing_iterator,
                              testing_handle, writer, test_count)
        print('All done!\nInitial test loss: {init}\nFinal test loss: {final}'
              .format(init=initial_test_loss, final=mean_test_loss))
    finally:
        # Flush pending summaries to disk and release session resources,
        # even if training or testing raised.
        if writer is not None:
            writer.close()
        sess.close()